From 2dcc98384fa37580e10088d56188442290e0971a Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Mon, 19 Oct 2020 13:26:54 -0700 Subject: [PATCH 01/12] regenerate library --- .kokoro/docs/common.cfg | 2 +- .kokoro/test-samples.sh | 8 +- docs/aiplatform_v1beta1/services.rst | 27 +- docs/aiplatform_v1beta1/types.rst | 4 +- docs/conf.py | 6 +- google/cloud/aiplatform_v1beta1/__init__.py | 339 +- .../services/dataset_service/__init__.py | 6 +- .../services/dataset_service/async_client.py | 1056 +++ .../services/dataset_service/client.py | 938 ++- .../services/dataset_service/pagers.py | 259 +- .../dataset_service/transports/__init__.py | 9 +- .../dataset_service/transports/base.py | 248 +- .../dataset_service/transports/grpc.py | 330 +- .../transports/grpc_asyncio.py | 499 ++ .../services/endpoint_service/__init__.py | 6 +- .../services/endpoint_service/async_client.py | 837 +++ .../services/endpoint_service/client.py | 749 +- .../services/endpoint_service/pagers.py | 88 +- .../endpoint_service/transports/__init__.py | 9 +- .../endpoint_service/transports/base.py | 203 +- .../endpoint_service/transports/grpc.py | 281 +- .../transports/grpc_asyncio.py | 423 ++ .../services/job_service/__init__.py | 6 +- .../services/job_service/async_client.py | 1905 +++++ .../services/job_service/client.py | 1587 +++-- .../services/job_service/pagers.py | 345 +- .../job_service/transports/__init__.py | 9 +- .../services/job_service/transports/base.py | 455 +- .../services/job_service/transports/grpc.py | 513 +- .../job_service/transports/grpc_asyncio.py | 811 +++ .../services/migration_service/__init__.py | 24 + .../migration_service/async_client.py | 357 + .../services/migration_service/client.py | 607 ++ .../services/migration_service/pagers.py | 143 + .../migration_service/transports/__init__.py | 36 + .../migration_service/transports/base.py | 148 + .../migration_service/transports/grpc.py | 292 + .../transports/grpc_asyncio.py | 297 + .../services/model_service/__init__.py | 6 +- .../services/model_service/async_client.py | 1044 +++ .../services/model_service/client.py | 938 ++- .../services/model_service/pagers.py | 259 +- .../model_service/transports/__init__.py | 9 +- .../services/model_service/transports/base.py | 252 +- .../services/model_service/transports/grpc.py | 330 +- .../model_service/transports/grpc_asyncio.py | 507 ++ .../services/pipeline_service/__init__.py | 6 +- .../services/pipeline_service/async_client.py | 598 ++ .../services/pipeline_service/client.py | 603 +- .../services/pipeline_service/pagers.py | 88 +- .../pipeline_service/transports/__init__.py | 9 +- .../pipeline_service/transports/base.py | 174 +- .../pipeline_service/transports/grpc.py | 266 +- .../transports/grpc_asyncio.py | 384 + .../services/prediction_service/__init__.py | 6 +- .../prediction_service/async_client.py | 380 + .../services/prediction_service/client.py | 425 +- .../prediction_service/transports/__init__.py | 9 +- .../prediction_service/transports/base.py | 113 +- .../prediction_service/transports/grpc.py | 215 +- .../transports/grpc_asyncio.py | 283 + .../specialist_pool_service/__init__.py | 6 +- .../specialist_pool_service/async_client.py | 636 ++ .../specialist_pool_service/client.py | 590 +- .../specialist_pool_service/pagers.py | 88 +- .../transports/__init__.py | 13 +- .../transports/base.py | 175 +- .../transports/grpc.py | 263 +- .../transports/grpc_asyncio.py | 375 + .../aiplatform_v1beta1/types/__init__.py | 528 +- .../types/accelerator_type.py | 5 +- .../aiplatform_v1beta1/types/annotation.py | 28 
+- .../types/annotation_spec.py | 17 +- .../types/batch_prediction_job.py | 143 +- .../types/completion_stats.py | 7 +- .../aiplatform_v1beta1/types/custom_job.py | 90 +- .../aiplatform_v1beta1/types/data_item.py | 22 +- .../types/data_labeling_job.py | 75 +- .../cloud/aiplatform_v1beta1/types/dataset.py | 38 +- .../types/dataset_service.py | 127 +- .../types/deployed_model_ref.py | 6 +- .../aiplatform_v1beta1/types/endpoint.py | 50 +- .../types/endpoint_service.py | 79 +- .../cloud/aiplatform_v1beta1/types/env_var.py | 6 +- .../aiplatform_v1beta1/types/explanation.py | 49 +- .../types/explanation_metadata.py | 27 +- .../types/hyperparameter_tuning_job.py | 55 +- google/cloud/aiplatform_v1beta1/types/io.py | 12 +- .../aiplatform_v1beta1/types/job_service.py | 135 +- .../aiplatform_v1beta1/types/job_state.py | 5 +- .../types/machine_resources.py | 35 +- .../types/manual_batch_tuning_parameters.py | 6 +- .../types/migratable_resource.py | 178 + .../types/migration_service.py | 308 + .../cloud/aiplatform_v1beta1/types/model.py | 365 +- .../types/model_evaluation.py | 22 +- .../types/model_evaluation_slice.py | 25 +- .../aiplatform_v1beta1/types/model_service.py | 120 +- .../aiplatform_v1beta1/types/operation.py | 25 +- .../types/pipeline_service.py | 36 +- .../types/pipeline_state.py | 5 +- .../types/prediction_service.py | 45 +- .../types/specialist_pool.py | 9 +- .../types/specialist_pool_service.py | 54 +- .../cloud/aiplatform_v1beta1/types/study.py | 100 +- .../types/training_pipeline.py | 125 +- .../types/user_action_reference.py | 11 +- mypy.ini | 2 +- noxfile.py | 25 +- scripts/fixup_aiplatform_v1beta1_keywords.py | 2 + setup.py | 52 +- synth.metadata | 4 +- synth.py | 4 +- .../unit/gapic/aiplatform_v1beta1/__init__.py | 1 + .../test_dataset_service.py | 3687 ++++++++++ .../test_endpoint_service.py | 2604 +++++++ .../aiplatform_v1beta1/test_job_service.py | 6289 +++++++++++++++++ .../test_migration_service.py | 1546 ++++ .../aiplatform_v1beta1/test_model_service.py | 3799 ++++++++++ .../test_pipeline_service.py | 2160 ++++++ .../test_prediction_service.py | 1220 ++++ .../test_specialist_pool_service.py | 2096 ++++++ tests/unit/gapic/test_dataset_service.py | 1140 --- tests/unit/gapic/test_endpoint_service.py | 771 -- tests/unit/gapic/test_job_service.py | 2118 ------ tests/unit/gapic/test_model_service.py | 1223 ---- tests/unit/gapic/test_pipeline_service.py | 675 -- tests/unit/gapic/test_prediction_service.py | 311 - .../gapic/test_specialist_pool_service.py | 681 -- 129 files changed, 44899 insertions(+), 11396 deletions(-) create mode 100644 google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/services/job_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py create mode 100644 
google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/model_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1beta1/types/migratable_resource.py
 create mode 100644 google/cloud/aiplatform_v1beta1/types/migration_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/__init__.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
 delete mode 100644 tests/unit/gapic/test_dataset_service.py
 delete mode 100644 tests/unit/gapic/test_endpoint_service.py
 delete mode 100644 tests/unit/gapic/test_job_service.py
 delete mode 100644 tests/unit/gapic/test_model_service.py
 delete mode 100644 tests/unit/gapic/test_pipeline_service.py
 delete mode 100644 tests/unit/gapic/test_prediction_service.py
 delete mode 100644 tests/unit/gapic/test_specialist_pool_service.py

diff --git a/.kokoro/docs/common.cfg b/.kokoro/docs/common.cfg
index b940b1d53f..5adc161f36 100644
--- a/.kokoro/docs/common.cfg
+++ b/.kokoro/docs/common.cfg
@@ -30,7 +30,7 @@ env_vars: {
 
 env_vars: {
     key: "V2_STAGING_BUCKET"
-    value: "docs-staging-v2-staging"
+    value: "docs-staging-v2"
 }
 
 # It will upload the docker image after successful builds.
diff --git a/.kokoro/test-samples.sh b/.kokoro/test-samples.sh
index 419c1a7e7d..aed13be6d4 100755
--- a/.kokoro/test-samples.sh
+++ b/.kokoro/test-samples.sh
@@ -28,6 +28,12 @@ if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
   git checkout $LATEST_RELEASE
 fi
 
+# Exit early if samples directory doesn't exist.
+# (Single quotes here, not backticks: inside a double-quoted string the
+# shell would treat `./samples` as command substitution and try to run it.)
+if [ ! -d "./samples" ]; then
+  echo "No tests run. Directory './samples' not found."
+  exit 0
+fi
+
 # Disable buffering, so that the logs stream through.
export PYTHONUNBUFFERED=1 @@ -101,4 +107,4 @@ cd "$ROOT" # Workaround for Kokoro permissions issue: delete secrets rm testing/{test-env.sh,client-secrets.json,service-account.json} -exit "$RTN" \ No newline at end of file +exit "$RTN" diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index a23ba7675f..664c9df0a8 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -1,6 +1,27 @@ -Client for Google Cloud Aiplatform API -====================================== +Services for Google Cloud Aiplatform v1beta1 API +================================================ -.. automodule:: google.cloud.aiplatform_v1beta1 +.. automodule:: google.cloud.aiplatform_v1beta1.services.dataset_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.endpoint_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.job_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.migration_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.model_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.pipeline_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.prediction_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.specialist_pool_service :members: :inherited-members: diff --git a/docs/aiplatform_v1beta1/types.rst b/docs/aiplatform_v1beta1/types.rst index df8cb24970..3f8a7c9d65 100644 --- a/docs/aiplatform_v1beta1/types.rst +++ b/docs/aiplatform_v1beta1/types.rst @@ -1,5 +1,5 @@ -Types for Google Cloud Aiplatform API -===================================== +Types for Google Cloud Aiplatform v1beta1 API +============================================= .. automodule:: google.cloud.aiplatform_v1beta1.types :members: diff --git a/docs/conf.py b/docs/conf.py index 5ac4507dba..b45ecd8682 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -347,8 +347,12 @@ intersphinx_mapping = { "python": ("http://python.readthedocs.org/en/latest/", None), "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), - "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,), + "google.api_core": ( + "https://googleapis.dev/python/google-api-core/latest/", + None, + ), "grpc": ("https://grpc.io/grpc/python/", None), + } diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index b99a73164d..da76eaf689 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -15,10 +15,10 @@ # limitations under the License. 
# - from .services.dataset_service import DatasetServiceClient from .services.endpoint_service import EndpointServiceClient from .services.job_service import JobServiceClient +from .services.migration_service import MigrationServiceClient from .services.model_service import ModelServiceClient from .services.pipeline_service import PipelineServiceClient from .services.prediction_service import PredictionServiceClient @@ -121,6 +121,14 @@ from .types.machine_resources import MachineSpec from .types.machine_resources import ResourcesConsumed from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters +from .types.migratable_resource import MigratableResource +from .types.migration_service import BatchMigrateResourcesOperationMetadata +from .types.migration_service import BatchMigrateResourcesRequest +from .types.migration_service import BatchMigrateResourcesResponse +from .types.migration_service import MigrateResourceRequest +from .types.migration_service import MigrateResourceResponse +from .types.migration_service import SearchMigratableResourcesRequest +from .types.migration_service import SearchMigratableResourcesResponse from .types.model import Model from .types.model import ModelContainerSpec from .types.model import Port @@ -179,164 +187,173 @@ __all__ = ( - "AcceleratorType", - "ActiveLearningConfig", - "Annotation", - "AnnotationSpec", - "Attribution", - "AutomaticResources", - "BatchDedicatedResources", - "BatchPredictionJob", - "BigQueryDestination", - "BigQuerySource", - "CancelBatchPredictionJobRequest", - "CancelCustomJobRequest", - "CancelDataLabelingJobRequest", - "CancelHyperparameterTuningJobRequest", - "CancelTrainingPipelineRequest", - "CompletionStats", - "ContainerRegistryDestination", - "ContainerSpec", - "CreateBatchPredictionJobRequest", - "CreateCustomJobRequest", - "CreateDataLabelingJobRequest", - "CreateDatasetOperationMetadata", - "CreateDatasetRequest", - "CreateEndpointOperationMetadata", - "CreateEndpointRequest", - "CreateHyperparameterTuningJobRequest", - "CreateSpecialistPoolOperationMetadata", - "CreateSpecialistPoolRequest", - "CreateTrainingPipelineRequest", - "CustomJob", - "CustomJobSpec", - "DataItem", - "DataLabelingJob", - "Dataset", - "DedicatedResources", - "DeleteBatchPredictionJobRequest", - "DeleteCustomJobRequest", - "DeleteDataLabelingJobRequest", - "DeleteDatasetRequest", - "DeleteEndpointRequest", - "DeleteHyperparameterTuningJobRequest", - "DeleteModelRequest", - "DeleteOperationMetadata", - "DeleteSpecialistPoolRequest", - "DeleteTrainingPipelineRequest", - "DeployModelOperationMetadata", - "DeployModelRequest", - "DeployModelResponse", - "DeployedModel", - "DeployedModelRef", - "Endpoint", - "EndpointServiceClient", - "EnvVar", - "ExplainRequest", - "ExplainResponse", - "Explanation", - "ExplanationMetadata", - "ExplanationParameters", - "ExplanationSpec", - "ExportDataConfig", - "ExportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportModelOperationMetadata", - "ExportModelRequest", - "ExportModelResponse", - "FilterSplit", - "FractionSplit", - "GcsDestination", - "GcsSource", - "GenericOperationMetadata", - "GetAnnotationSpecRequest", - "GetBatchPredictionJobRequest", - "GetCustomJobRequest", - "GetDataLabelingJobRequest", - "GetDatasetRequest", - "GetEndpointRequest", - "GetHyperparameterTuningJobRequest", - "GetModelEvaluationRequest", - "GetModelEvaluationSliceRequest", - "GetModelRequest", - "GetSpecialistPoolRequest", - "GetTrainingPipelineRequest", - "HyperparameterTuningJob", - 
"ImportDataConfig", - "ImportDataOperationMetadata", - "ImportDataRequest", - "ImportDataResponse", - "InputDataConfig", - "JobServiceClient", - "JobState", - "ListAnnotationsRequest", - "ListAnnotationsResponse", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "ListDataItemsRequest", - "ListDataItemsResponse", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "ListDatasetsRequest", - "ListDatasetsResponse", - "ListEndpointsRequest", - "ListEndpointsResponse", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "ListModelsRequest", - "ListModelsResponse", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "MachineSpec", - "ManualBatchTuningParameters", - "Measurement", - "Model", - "ModelContainerSpec", - "ModelEvaluation", - "ModelEvaluationSlice", - "ModelExplanation", - "ModelServiceClient", - "PipelineServiceClient", - "PipelineState", - "Port", - "PredefinedSplit", - "PredictRequest", - "PredictResponse", - "PredictSchemata", - "PredictionServiceClient", - "PythonPackageSpec", - "ResourcesConsumed", - "SampleConfig", - "SampledShapleyAttribution", - "Scheduling", - "SpecialistPool", - "SpecialistPoolServiceClient", - "StudySpec", - "TimestampSplit", - "TrainingConfig", - "TrainingPipeline", - "Trial", - "UndeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UpdateDatasetRequest", - "UpdateEndpointRequest", - "UpdateModelRequest", - "UpdateSpecialistPoolOperationMetadata", - "UpdateSpecialistPoolRequest", - "UploadModelOperationMetadata", - "UploadModelRequest", - "UploadModelResponse", - "UserActionReference", - "WorkerPoolSpec", - "DatasetServiceClient", + 'AcceleratorType', + 'ActiveLearningConfig', + 'Annotation', + 'AnnotationSpec', + 'Attribution', + 'AutomaticResources', + 'BatchDedicatedResources', + 'BatchMigrateResourcesOperationMetadata', + 'BatchMigrateResourcesRequest', + 'BatchMigrateResourcesResponse', + 'BatchPredictionJob', + 'BigQueryDestination', + 'BigQuerySource', + 'CancelBatchPredictionJobRequest', + 'CancelCustomJobRequest', + 'CancelDataLabelingJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CancelTrainingPipelineRequest', + 'CompletionStats', + 'ContainerRegistryDestination', + 'ContainerSpec', + 'CreateBatchPredictionJobRequest', + 'CreateCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'CreateDatasetOperationMetadata', + 'CreateDatasetRequest', + 'CreateEndpointOperationMetadata', + 'CreateEndpointRequest', + 'CreateHyperparameterTuningJobRequest', + 'CreateSpecialistPoolOperationMetadata', + 'CreateSpecialistPoolRequest', + 'CreateTrainingPipelineRequest', + 'CustomJob', + 'CustomJobSpec', + 'DataItem', + 'DataLabelingJob', + 'Dataset', + 'DedicatedResources', + 'DeleteBatchPredictionJobRequest', + 'DeleteCustomJobRequest', + 'DeleteDataLabelingJobRequest', + 'DeleteDatasetRequest', + 'DeleteEndpointRequest', + 'DeleteHyperparameterTuningJobRequest', + 'DeleteModelRequest', + 'DeleteOperationMetadata', + 'DeleteSpecialistPoolRequest', + 'DeleteTrainingPipelineRequest', + 'DeployModelOperationMetadata', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployedModel', + 'DeployedModelRef', + 'Endpoint', + 'EndpointServiceClient', + 'EnvVar', + 
'ExplainRequest', + 'ExplainResponse', + 'Explanation', + 'ExplanationMetadata', + 'ExplanationParameters', + 'ExplanationSpec', + 'ExportDataConfig', + 'ExportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportModelOperationMetadata', + 'ExportModelRequest', + 'ExportModelResponse', + 'FilterSplit', + 'FractionSplit', + 'GcsDestination', + 'GcsSource', + 'GenericOperationMetadata', + 'GetAnnotationSpecRequest', + 'GetBatchPredictionJobRequest', + 'GetCustomJobRequest', + 'GetDataLabelingJobRequest', + 'GetDatasetRequest', + 'GetEndpointRequest', + 'GetHyperparameterTuningJobRequest', + 'GetModelEvaluationRequest', + 'GetModelEvaluationSliceRequest', + 'GetModelRequest', + 'GetSpecialistPoolRequest', + 'GetTrainingPipelineRequest', + 'HyperparameterTuningJob', + 'ImportDataConfig', + 'ImportDataOperationMetadata', + 'ImportDataRequest', + 'ImportDataResponse', + 'InputDataConfig', + 'JobServiceClient', + 'JobState', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'ListModelsRequest', + 'ListModelsResponse', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'MachineSpec', + 'ManualBatchTuningParameters', + 'Measurement', + 'MigratableResource', + 'MigrateResourceRequest', + 'MigrateResourceResponse', + 'MigrationServiceClient', + 'Model', + 'ModelContainerSpec', + 'ModelEvaluation', + 'ModelEvaluationSlice', + 'ModelExplanation', + 'ModelServiceClient', + 'PipelineServiceClient', + 'PipelineState', + 'Port', + 'PredefinedSplit', + 'PredictRequest', + 'PredictResponse', + 'PredictSchemata', + 'PredictionServiceClient', + 'PythonPackageSpec', + 'ResourcesConsumed', + 'SampleConfig', + 'SampledShapleyAttribution', + 'Scheduling', + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'SpecialistPool', + 'SpecialistPoolServiceClient', + 'StudySpec', + 'TimestampSplit', + 'TrainingConfig', + 'TrainingPipeline', + 'Trial', + 'UndeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UpdateDatasetRequest', + 'UpdateEndpointRequest', + 'UpdateModelRequest', + 'UpdateSpecialistPoolOperationMetadata', + 'UpdateSpecialistPoolRequest', + 'UploadModelOperationMetadata', + 'UploadModelRequest', + 'UploadModelResponse', + 'UserActionReference', + 'WorkerPoolSpec', +'DatasetServiceClient', ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py index 8b973db167..9d1f004f6a 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py @@ -16,5 +16,9 @@ # from .client import DatasetServiceClient +from .async_client import DatasetServiceAsyncClient -__all__ = ("DatasetServiceClient",) +__all__ = ( + 'DatasetServiceClient', + 'DatasetServiceAsyncClient', +) diff --git 
a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py new file mode 100644 index 0000000000..8b67c83c6a --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -0,0 +1,1056 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers +from google.cloud.aiplatform_v1beta1.types import annotation +from google.cloud.aiplatform_v1beta1.types import annotation_spec +from google.cloud.aiplatform_v1beta1.types import data_item +from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport +from .client import DatasetServiceClient + + +class DatasetServiceAsyncClient: + """""" + + _client: DatasetServiceClient + + DEFAULT_ENDPOINT = DatasetServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DatasetServiceClient.DEFAULT_MTLS_ENDPOINT + + annotation_path = staticmethod(DatasetServiceClient.annotation_path) + parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) + annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) + parse_annotation_spec_path = staticmethod(DatasetServiceClient.parse_annotation_spec_path) + data_item_path = staticmethod(DatasetServiceClient.data_item_path) + parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) + dataset_path = staticmethod(DatasetServiceClient.dataset_path) + parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) + + common_billing_account_path = staticmethod(DatasetServiceClient.common_billing_account_path) + parse_common_billing_account_path = 
staticmethod(DatasetServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(DatasetServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(DatasetServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(DatasetServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(DatasetServiceClient.common_project_path) + parse_common_project_path = staticmethod(DatasetServiceClient.parse_common_project_path) + + common_location_path = staticmethod(DatasetServiceClient.common_location_path) + parse_common_location_path = staticmethod(DatasetServiceClient.parse_common_location_path) + + from_service_account_file = DatasetServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> DatasetServiceTransport: + """Return the transport used by the client instance. + + Returns: + DatasetServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the dataset service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.DatasetServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. 
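+
+        Example (an illustrative sketch; the regional endpoint value is
+        an assumption, not something this patch pins down):
+
+            options = ClientOptions.ClientOptions(
+                api_endpoint="us-central1-aiplatform.googleapis.com")
+            client = DatasetServiceAsyncClient(client_options=options)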
+ """ + + self._client = DatasetServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_dataset(self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a Dataset. + + Args: + request (:class:`~.dataset_service.CreateDatasetRequest`): + The request object. Request message for + ``DatasetService.CreateDataset``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Dataset in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + dataset (:class:`~.gca_dataset.Dataset`): + Required. The Dataset to create. + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_dataset.Dataset`: A collection of + DataItems and Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, dataset]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.CreateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_dataset.Dataset, + metadata_type=dataset_service.CreateDatasetOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_dataset(self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: + r"""Gets a Dataset. + + Args: + request (:class:`~.dataset_service.GetDatasetRequest`): + The request object. Request message for + ``DatasetService.GetDataset``. + name (:class:`str`): + Required. 
The name of the Dataset + resource. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.dataset.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.GetDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_dataset(self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: + r"""Updates a Dataset. + + Args: + request (:class:`~.dataset_service.UpdateDatasetRequest`): + The request object. Request message for + ``DatasetService.UpdateDataset``. + dataset (:class:`~.gca_dataset.Dataset`): + Required. The Dataset which replaces + the resource on the server. + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + + [FieldMask](https://tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask). + Updatable fields: + + - ``display_name`` + - ``description`` + - ``labels`` + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_dataset.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
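+        # (Illustrative call shapes, assuming `ds` and `mask` are a Dataset
+        # and FieldMask built by the caller: pass either the flattened
+        # fields, e.g. `await client.update_dataset(dataset=ds,
+        # update_mask=mask)`, or a fully-formed request object, e.g.
+        # `await client.update_dataset(request=req)`; mixing the two
+        # raises ValueError below.)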
+ if request is not None and any([dataset, update_mask]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.UpdateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if dataset is not None: + request.dataset = dataset + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('dataset.name', request.dataset.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_datasets(self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: + r"""Lists Datasets in a Location. + + Args: + request (:class:`~.dataset_service.ListDatasetsRequest`): + The request object. Request message for + ``DatasetService.ListDatasets``. + parent (:class:`str`): + Required. The name of the Dataset's parent resource. + Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListDatasetsAsyncPager: + Response message for + ``DatasetService.ListDatasets``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.ListDatasetsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_datasets, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. 
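+        # (For example, a caller can lazily walk every page with an async
+        # for-loop -- a sketch, assuming `client` and `parent` are set up:
+        #     pager = await client.list_datasets(parent=parent)
+        #     async for dataset in pager:
+        #         print(dataset.display_name)
+        # )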
+ response = pagers.ListDatasetsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_dataset(self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Dataset. + + Args: + request (:class:`~.dataset_service.DeleteDatasetRequest`): + The request object. Request message for + ``DatasetService.DeleteDataset``. + name (:class:`str`): + Required. The resource name of the Dataset to delete. + Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.DeleteDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def import_data(self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Imports data into a Dataset. + + Args: + request (:class:`~.dataset_service.ImportDataRequest`): + The request object. 
Request message for + ``DatasetService.ImportData``. + name (:class:`str`): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + import_configs (:class:`Sequence[~.dataset.ImportDataConfig]`): + Required. The desired input + locations. The contents of all input + locations will be imported in one batch. + This corresponds to the ``import_configs`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.dataset_service.ImportDataResponse`: + Response message for + ``DatasetService.ImportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name, import_configs]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.ImportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if import_configs is not None: + request.import_configs = import_configs + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.import_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + dataset_service.ImportDataResponse, + metadata_type=dataset_service.ImportDataOperationMetadata, + ) + + # Done; return the response. + return response + + async def export_data(self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Exports data from a Dataset. + + Args: + request (:class:`~.dataset_service.ExportDataRequest`): + The request object. Request message for + ``DatasetService.ExportData``. + name (:class:`str`): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + export_config (:class:`~.dataset.ExportDataConfig`): + Required. The desired output + location. 
+ This corresponds to the ``export_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.dataset_service.ExportDataResponse`: + Response message for + ``DatasetService.ExportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name, export_config]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = dataset_service.ExportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + dataset_service.ExportDataResponse, + metadata_type=dataset_service.ExportDataOperationMetadata, + ) + + # Done; return the response. + return response + + async def list_data_items(self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: + r"""Lists DataItems in a Dataset. + + Args: + request (:class:`~.dataset_service.ListDataItemsRequest`): + The request object. Request message for + ``DatasetService.ListDataItems``. + parent (:class:`str`): + Required. The resource name of the Dataset to list + DataItems from. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListDataItemsAsyncPager: + Response message for + ``DatasetService.ListDataItems``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+        if request is not None and any([parent]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = dataset_service.ListDataItemsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_data_items,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', request.parent),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListDataItemsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def get_annotation_spec(self,
+            request: dataset_service.GetAnnotationSpecRequest = None,
+            *,
+            name: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> annotation_spec.AnnotationSpec:
+        r"""Gets an AnnotationSpec.
+
+        Args:
+            request (:class:`~.dataset_service.GetAnnotationSpecRequest`):
+                The request object. Request message for
+                ``DatasetService.GetAnnotationSpec``.
+            name (:class:`str`):
+                Required. The name of the AnnotationSpec resource.
+                Format:
+
+                ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.annotation_spec.AnnotationSpec:
+                Identifies a concept with which
+                DataItems may be annotated.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([name]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = dataset_service.GetAnnotationSpecRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.get_annotation_spec,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('name', request.name),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def list_annotations(self,
+            request: dataset_service.ListAnnotationsRequest = None,
+            *,
+            parent: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> pagers.ListAnnotationsAsyncPager:
+        r"""Lists Annotations that belong to a DataItem.
+
+        Args:
+            request (:class:`~.dataset_service.ListAnnotationsRequest`):
+                The request object. Request message for
+                ``DatasetService.ListAnnotations``.
+            parent (:class:`str`):
+                Required. The resource name of the DataItem to list
+                Annotations from. Format:
+
+                ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.pagers.ListAnnotationsAsyncPager:
+                Response message for
+                ``DatasetService.ListAnnotations``.
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([parent]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = dataset_service.ListAnnotationsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_annotations,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', request.parent),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListAnnotationsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'DatasetServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 99dcfbc5b6..b60c70f7c9 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import annotation_spec @@ -40,8 +47,9 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import DatasetServiceTransport +from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import DatasetServiceGrpcTransport +from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport class DatasetServiceClientMeta(type): @@ -51,13 +59,13 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry['grpc'] = DatasetServiceGrpcTransport + _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry["grpc"] = DatasetServiceGrpcTransport - - def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. 
         Args:
@@ -79,8 +87,38 @@ def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport
 class DatasetServiceClient(metaclass=DatasetServiceClientMeta):
     """"""
 
-    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
-        api_endpoint="aiplatform.googleapis.com"
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -97,26 +135,127 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
 
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(filename)
-        kwargs["credentials"] = credentials
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs['credentials'] = credentials
         return cls(*args, **kwargs)
 
     from_service_account_json = from_service_account_file
 
+    @property
+    def transport(self) -> DatasetServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            DatasetServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
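The endpoint-conversion helper above is what seeds ``DEFAULT_MTLS_ENDPOINT``; a
standalone sketch of its behavior, inferred from the docstring and body in this
hunk (the hostnames are illustrative):

    from google.cloud.aiplatform_v1beta1.services.dataset_service.client import (
        DatasetServiceClient,
    )

    # A plain production endpoint gains the ".mtls" label after the service name.
    assert DatasetServiceClient._get_default_mtls_endpoint(
        'aiplatform.googleapis.com') == 'aiplatform.mtls.googleapis.com'

    # Sandbox endpoints are rewritten under "mtls.sandbox".
    assert DatasetServiceClient._get_default_mtls_endpoint(
        'aiplatform.sandbox.googleapis.com') == 'aiplatform.mtls.sandbox.googleapis.com'

    # Already-mTLS endpoints and non-Google hosts pass through unchanged.
    assert DatasetServiceClient._get_default_mtls_endpoint(
        'example.com') == 'example.com'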
+ """ + return self._transport + + @staticmethod + def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: + """Return a fully-qualified annotation string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + + @staticmethod + def parse_annotation_path(path: str) -> Dict[str,str]: + """Parse a annotation path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: + """Return a fully-qualified annotation_spec string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + + @staticmethod + def parse_annotation_spec_path(path: str) -> Dict[str,str]: + """Parse a annotation_spec path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: + """Return a fully-qualified data_item string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + + @staticmethod + def parse_data_item_path(path: str) -> Dict[str,str]: + """Parse a data_item path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) + return m.groupdict() if m else {} + @staticmethod - def dataset_path(project: str, location: str, dataset: str,) -> str: + def dataset_path(project: str,location: str,dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + + @staticmethod + def parse_dataset_path(path: str) -> Dict[str,str]: + """Parse a dataset path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + 
+
+    @staticmethod
+    def common_organization_path(organization: str, ) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization, )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str,str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str, ) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project, )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str,str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str, ) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(project=project, location=location, )
 
-    def __init__(
-        self,
-        *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, DatasetServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
-    ) -> None:
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str,str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(self, *,
+            credentials: Optional[credentials.Credentials] = None,
+            transport: Union[str, DatasetServiceTransport, None] = None,
+            client_options: Optional[client_options_lib.ClientOptions] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
         """Instantiate the dataset service client.
 
         Args:
@@ -128,38 +267,107 @@ def __init__(
             transport (Union[str, ~.DatasetServiceTransport]): The transport to
                 use. If set to None, a transport is chosen
                 automatically.
-            client_options (ClientOptions): Custom options for the client.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
""" if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, DatasetServiceTransport): - if credentials: + # transport is a DatasetServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: raise ValueError( "When providing a transport instance, " - "provide its credentials directly." + "provide its scopes directly." ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) - def create_dataset( - self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_dataset(self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a Dataset. Args: @@ -197,32 +405,45 @@ def create_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, dataset]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) - - request = dataset_service.CreateDatasetRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if dataset is not None: - request.dataset = dataset + has_flattened_params = any([parent, dataset]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.CreateDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.CreateDatasetRequest): + request = dataset_service.CreateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_dataset, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -235,15 +456,14 @@ def create_dataset( # Done; return the response. return response - def get_dataset( - self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset(self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -272,48 +492,56 @@ def get_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = dataset_service.GetDatasetRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.GetDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.GetDatasetRequest): + request = dataset_service.GetDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_dataset, default_timeout=None, client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_dataset] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def update_dataset( - self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset(self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -355,45 +583,57 @@ def update_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([dataset, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.UpdateDatasetRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if dataset is not None: - request.dataset = dataset - if update_mask is not None: - request.update_mask = update_mask + has_flattened_params = any([dataset, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.UpdateDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.UpdateDatasetRequest): + request = dataset_service.UpdateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if dataset is not None: + request.dataset = dataset + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_dataset, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('dataset.name', request.dataset.name), + )), ) # Send the request. 
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_datasets( - self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets(self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -425,55 +665,64 @@ def list_datasets( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = dataset_service.ListDatasetsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListDatasetsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListDatasetsRequest): + request = dataset_service.ListDatasetsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_datasets, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_datasets] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. 
return response - def delete_dataset( - self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_dataset(self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Dataset. Args: @@ -518,30 +767,43 @@ def delete_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = dataset_service.DeleteDatasetRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.DeleteDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.DeleteDatasetRequest): + request = dataset_service.DeleteDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_dataset, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -554,16 +816,15 @@ def delete_dataset( # Done; return the response. return response - def import_data( - self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def import_data(self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Imports data into a Dataset. Args: @@ -603,30 +864,46 @@ def import_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([name, import_configs]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name, import_configs]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ImportDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ImportDataRequest): + request = dataset_service.ImportDataRequest(request) - request = dataset_service.ImportDataRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. - # If we have keyword arguments corresponding to fields on the - # request, apply these. + if name is not None: + request.name = name - if name is not None: - request.name = name - if import_configs is not None: - request.import_configs = import_configs + if import_configs: + request.import_configs.extend(import_configs) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.import_data, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.import_data] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -639,16 +916,15 @@ def import_data( # Done; return the response. return response - def export_data( - self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_data(self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports data from a Dataset. Args: @@ -687,30 +963,45 @@ def export_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, export_config]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.ExportDataRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
- - if name is not None: - request.name = name - if export_config is not None: - request.export_config = export_config + has_flattened_params = any([name, export_config]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ExportDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ExportDataRequest): + request = dataset_service.ExportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.export_data, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.export_data] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -723,15 +1014,14 @@ def export_data( # Done; return the response. return response - def list_data_items( - self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items(self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. Args: @@ -764,55 +1054,64 @@ def list_data_items( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = dataset_service.ListDataItemsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListDataItemsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListDataItemsRequest): + request = dataset_service.ListDataItemsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_data_items, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_data_items] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec( - self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec(self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -843,49 +1142,55 @@ def get_annotation_spec( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = dataset_service.GetAnnotationSpecRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.GetAnnotationSpecRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.GetAnnotationSpecRequest): + request = dataset_service.GetAnnotationSpecRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_annotation_spec, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_annotation_spec] # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_annotations( - self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations(self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -919,55 +1224,72 @@ def list_annotations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = dataset_service.ListAnnotationsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListAnnotationsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListAnnotationsRequest): + request = dataset_service.ListAnnotationsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_annotations, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_annotations] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. 
return response + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("DatasetServiceClient",) +__all__ = ( + 'DatasetServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index 0dd2e668cc..af29515c1d 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item @@ -40,15 +40,12 @@ class ListDatasetsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse - ], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - ): + def __init__(self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -58,10 +55,13 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListDatasetsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -71,7 +71,7 @@ def pages(self) -> Iterable[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[dataset.Dataset]: @@ -79,7 +79,70 @@ def __iter__(self) -> Iterable[dataset.Dataset]: yield from page.datasets def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListDatasetsAsyncPager: + """A pager for iterating through ``list_datasets`` requests. + + This class thinly wraps an initial + :class:`~.dataset_service.ListDatasetsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``datasets`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDatasets`` requests and continue to iterate + through the ``datasets`` field on the + corresponding responses. + + All the usual :class:`~.dataset_service.ListDatasetsResponse` + attributes are available on the pager. 
If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.dataset_service.ListDatasetsRequest`): + The initial request object. + response (:class:`~.dataset_service.ListDatasetsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListDatasetsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListDatasetsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[dataset.Dataset]: + async def async_generator(): + async for page in self.pages: + for response in page.datasets: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataItemsPager: @@ -99,16 +162,12 @@ class ListDataItemsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [dataset_service.ListDataItemsRequest], - dataset_service.ListDataItemsResponse, - ], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - ): + def __init__(self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -118,10 +177,13 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListDataItemsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
""" self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -131,7 +193,7 @@ def pages(self) -> Iterable[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[data_item.DataItem]: @@ -139,7 +201,70 @@ def __iter__(self) -> Iterable[data_item.DataItem]: yield from page.data_items def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListDataItemsAsyncPager: + """A pager for iterating through ``list_data_items`` requests. + + This class thinly wraps an initial + :class:`~.dataset_service.ListDataItemsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``data_items`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDataItems`` requests and continue to iterate + through the ``data_items`` field on the + corresponding responses. + + All the usual :class:`~.dataset_service.ListDataItemsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.dataset_service.ListDataItemsRequest`): + The initial request object. + response (:class:`~.dataset_service.ListDataItemsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListDataItemsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListDataItemsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[data_item.DataItem]: + async def async_generator(): + async for page in self.pages: + for response in page.data_items: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListAnnotationsPager: @@ -159,16 +284,12 @@ class ListAnnotationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - - def __init__( - self, - method: Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, - ], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - ): + def __init__(self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -178,10 +299,13 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListAnnotationsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -191,7 +315,7 @@ def pages(self) -> Iterable[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[annotation.Annotation]: @@ -199,4 +323,67 @@ def __iter__(self) -> Iterable[annotation.Annotation]: yield from page.annotations def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListAnnotationsAsyncPager: + """A pager for iterating through ``list_annotations`` requests. + + This class thinly wraps an initial + :class:`~.dataset_service.ListAnnotationsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``annotations`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListAnnotations`` requests and continue to iterate + through the ``annotations`` field on the + corresponding responses. + + All the usual :class:`~.dataset_service.ListAnnotationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.dataset_service.ListAnnotationsRequest`): + The initial request object. + response (:class:`~.dataset_service.ListAnnotationsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = dataset_service.ListAnnotationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListAnnotationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[annotation.Annotation]: + async def async_generator(): + async for page in self.pages: + for response in page.annotations: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index 7f1cb8ca21..fd4e511640 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import DatasetServiceTransport from .grpc import DatasetServiceGrpcTransport +from .grpc_asyncio import DatasetServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] -_transport_registry["grpc"] = DatasetServiceGrpcTransport +_transport_registry['grpc'] = DatasetServiceGrpcTransport +_transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport __all__ = ( - "DatasetServiceTransport", - "DatasetServiceGrpcTransport", + 'DatasetServiceTransport', + 'DatasetServiceGrpcTransport', + 'DatasetServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index f00538959f..1fa9766314 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -29,17 +33,32 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class DatasetServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials 
= None,
+            credentials_file: typing.Optional[str] = None,
+            scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+            quota_project_id: typing.Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            **kwargs,
+            ) -> None:
         """Instantiate the transport.
 
         Args:
@@ -49,93 +68,196 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
-        if ":" not in host:
-            host += ":443"
+        if ':' not in host:
+            host += ':443'
         self._host = host
 
         # If no credentials are provided, then determine the appropriate
         # defaults.
-        if credentials is None:
-            credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file,
+                scopes=scopes,
+                quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)
 
         # Save the credentials.
         self._credentials = credentials
 
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
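+        # gapic_v1.method.wrap_method layers the default retry/timeout policy
+        # and user-agent metadata onto each raw transport callable. As an
+        # illustrative sketch only (not part of the generated surface), a
+        # wrapped method is invoked like the bare stub, with optional
+        # per-call overrides:
+        #
+        #     wrapped = gapic_v1.method.wrap_method(
+        #         self.get_dataset, default_timeout=None, client_info=client_info)
+        #     response = wrapped(request, timeout=30.0)
+        #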
+ self._wrapped_methods = { + self.create_dataset: gapic_v1.method.wrap_method( + self.create_dataset, + default_timeout=None, + client_info=client_info, + ), + self.get_dataset: gapic_v1.method.wrap_method( + self.get_dataset, + default_timeout=None, + client_info=client_info, + ), + self.update_dataset: gapic_v1.method.wrap_method( + self.update_dataset, + default_timeout=None, + client_info=client_info, + ), + self.list_datasets: gapic_v1.method.wrap_method( + self.list_datasets, + default_timeout=None, + client_info=client_info, + ), + self.delete_dataset: gapic_v1.method.wrap_method( + self.delete_dataset, + default_timeout=None, + client_info=client_info, + ), + self.import_data: gapic_v1.method.wrap_method( + self.import_data, + default_timeout=None, + client_info=client_info, + ), + self.export_data: gapic_v1.method.wrap_method( + self.export_data, + default_timeout=None, + client_info=client_info, + ), + self.list_data_items: gapic_v1.method.wrap_method( + self.list_data_items, + default_timeout=None, + client_info=client_info, + ), + self.get_annotation_spec: gapic_v1.method.wrap_method( + self.get_annotation_spec, + default_timeout=None, + client_info=client_info, + ), + self.list_annotations: gapic_v1.method.wrap_method( + self.list_annotations, + default_timeout=None, + client_info=client_info, + ), + + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property - def create_dataset( - self, - ) -> typing.Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: - raise NotImplementedError + def create_dataset(self) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def get_dataset( - self, - ) -> typing.Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: - raise NotImplementedError + def get_dataset(self) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[ + dataset.Dataset, + typing.Awaitable[dataset.Dataset] + ]]: + raise NotImplementedError() @property - def update_dataset( - self, - ) -> typing.Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: - raise NotImplementedError + def update_dataset(self) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[ + gca_dataset.Dataset, + typing.Awaitable[gca_dataset.Dataset] + ]]: + raise NotImplementedError() @property - def list_datasets( - self, - ) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse - ]: - raise NotImplementedError + def list_datasets(self) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse] + ]]: + raise NotImplementedError() @property - def delete_dataset( - self, - ) -> typing.Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: - raise NotImplementedError + def delete_dataset(self) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def import_data( - self, - ) -> typing.Callable[[dataset_service.ImportDataRequest], operations.Operation]: - raise NotImplementedError + def import_data(self) -> typing.Callable[ 
+ [dataset_service.ImportDataRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def export_data( - self, - ) -> typing.Callable[[dataset_service.ExportDataRequest], operations.Operation]: - raise NotImplementedError + def export_data(self) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def list_data_items( - self, - ) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse - ]: - raise NotImplementedError + def list_data_items(self) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse] + ]]: + raise NotImplementedError() @property - def get_annotation_spec( - self, - ) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec - ]: - raise NotImplementedError + def get_annotation_spec(self) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec] + ]]: + raise NotImplementedError() @property - def list_annotations( - self, - ) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, - ]: - raise NotImplementedError + def list_annotations(self) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse] + ]]: + raise NotImplementedError() -__all__ = ("DatasetServiceTransport",) +__all__ = ( + 'DatasetServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 8110a97c2c..3914e0d35b 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -29,7 +33,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import DatasetServiceTransport +from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO class DatasetServiceGrpcTransport(DatasetServiceTransport): @@ -42,14 +46,20 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. 
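+
+    Example (an illustrative sketch, assuming application default
+    credentials are available in the environment)::
+
+        transport = DatasetServiceGrpcTransport()
+        client = DatasetServiceClient(transport=transport)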
""" - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - channel: grpc.Channel = None - ) -> None: + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -60,29 +70,107 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ - # Sanity check: Ensure that channel and credentials are not both - # provided. if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. credentials = False - # Run the base constructor. - super().__init__(host=host, credentials=credentials) + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. 
The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._stubs = {} # type: Dict[str, Callable] - # If a channel was explicitly provided, set it. - if channel: - self._grpc_channel = channel + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - **kwargs - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optionsl[str]): The host for the channel to use. @@ -91,30 +179,37 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property @@ -125,18 +220,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. 
""" # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self.__dict__['operations_client'] @property - def create_dataset( - self, - ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: + def create_dataset(self) -> Callable[ + [dataset_service.CreateDatasetRequest], + operations.Operation]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -151,18 +246,18 @@ def create_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_dataset" not in self._stubs: - self._stubs["create_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", + if 'create_dataset' not in self._stubs: + self._stubs['create_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_dataset"] + return self._stubs['create_dataset'] @property - def get_dataset( - self, - ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: + def get_dataset(self) -> Callable[ + [dataset_service.GetDatasetRequest], + dataset.Dataset]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -177,18 +272,18 @@ def get_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_dataset" not in self._stubs: - self._stubs["get_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", + if 'get_dataset' not in self._stubs: + self._stubs['get_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs["get_dataset"] + return self._stubs['get_dataset'] @property - def update_dataset( - self, - ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: + def update_dataset(self) -> Callable[ + [dataset_service.UpdateDatasetRequest], + gca_dataset.Dataset]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -203,20 +298,18 @@ def update_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "update_dataset" not in self._stubs: - self._stubs["update_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", + if 'update_dataset' not in self._stubs: + self._stubs['update_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs["update_dataset"] + return self._stubs['update_dataset'] @property - def list_datasets( - self, - ) -> Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse - ]: + def list_datasets(self) -> Callable[ + [dataset_service.ListDatasetsRequest], + dataset_service.ListDatasetsResponse]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -231,18 +324,18 @@ def list_datasets( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_datasets" not in self._stubs: - self._stubs["list_datasets"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", + if 'list_datasets' not in self._stubs: + self._stubs['list_datasets'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs["list_datasets"] + return self._stubs['list_datasets'] @property - def delete_dataset( - self, - ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: + def delete_dataset(self) -> Callable[ + [dataset_service.DeleteDatasetRequest], + operations.Operation]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -257,18 +350,18 @@ def delete_dataset( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_dataset" not in self._stubs: - self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", + if 'delete_dataset' not in self._stubs: + self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_dataset"] + return self._stubs['delete_dataset'] @property - def import_data( - self, - ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: + def import_data(self) -> Callable[ + [dataset_service.ImportDataRequest], + operations.Operation]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -283,18 +376,18 @@ def import_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "import_data" not in self._stubs: - self._stubs["import_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", + if 'import_data' not in self._stubs: + self._stubs['import_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["import_data"] + return self._stubs['import_data'] @property - def export_data( - self, - ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: + def export_data(self) -> Callable[ + [dataset_service.ExportDataRequest], + operations.Operation]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -309,20 +402,18 @@ def export_data( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_data" not in self._stubs: - self._stubs["export_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", + if 'export_data' not in self._stubs: + self._stubs['export_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_data"] + return self._stubs['export_data'] @property - def list_data_items( - self, - ) -> Callable[ - [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse - ]: + def list_data_items(self) -> Callable[ + [dataset_service.ListDataItemsRequest], + dataset_service.ListDataItemsResponse]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -337,20 +428,18 @@ def list_data_items( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_items" not in self._stubs: - self._stubs["list_data_items"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", + if 'list_data_items' not in self._stubs: + self._stubs['list_data_items'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs["list_data_items"] + return self._stubs['list_data_items'] @property - def get_annotation_spec( - self, - ) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec - ]: + def get_annotation_spec(self) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + annotation_spec.AnnotationSpec]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -365,21 +454,18 @@ def get_annotation_spec( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "get_annotation_spec" not in self._stubs: - self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", + if 'get_annotation_spec' not in self._stubs: + self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs["get_annotation_spec"] + return self._stubs['get_annotation_spec'] @property - def list_annotations( - self, - ) -> Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, - ]: + def list_annotations(self) -> Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations belongs to a dataitem @@ -394,13 +480,15 @@ def list_annotations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_annotations" not in self._stubs: - self._stubs["list_annotations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", + if 'list_annotations' not in self._stubs: + self._stubs['list_annotations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs["list_annotations"] + return self._stubs['list_annotations'] -__all__ = ("DatasetServiceGrpcTransport",) +__all__ = ( + 'DatasetServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..c8d51ca917 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -0,0 +1,499 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import annotation_spec
+from google.cloud.aiplatform_v1beta1.types import dataset
+from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset
+from google.cloud.aiplatform_v1beta1.types import dataset_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import DatasetServiceGrpcTransport
+
+
+class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport):
+    """gRPC AsyncIO backend transport for DatasetService.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(cls,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            **kwargs) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            address (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
+        )
+
+    def __init__(self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: aio.Channel = None,
+            api_mtls_endpoint: str = None,
+            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+            ssl_channel_credentials: grpc.ChannelCredentials = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + + @property + def create_dataset(self) -> Callable[ + [dataset_service.CreateDatasetRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create dataset method over gRPC. + + Creates a Dataset. + + Returns: + Callable[[~.CreateDatasetRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_dataset' not in self._stubs: + self._stubs['create_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', + request_serializer=dataset_service.CreateDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_dataset'] + + @property + def get_dataset(self) -> Callable[ + [dataset_service.GetDatasetRequest], + Awaitable[dataset.Dataset]]: + r"""Return a callable for the get dataset method over gRPC. + + Gets a Dataset. + + Returns: + Callable[[~.GetDatasetRequest], + Awaitable[~.Dataset]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_dataset' not in self._stubs: + self._stubs['get_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', + request_serializer=dataset_service.GetDatasetRequest.serialize, + response_deserializer=dataset.Dataset.deserialize, + ) + return self._stubs['get_dataset'] + + @property + def update_dataset(self) -> Callable[ + [dataset_service.UpdateDatasetRequest], + Awaitable[gca_dataset.Dataset]]: + r"""Return a callable for the update dataset method over gRPC. + + Updates a Dataset. + + Returns: + Callable[[~.UpdateDatasetRequest], + Awaitable[~.Dataset]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_dataset' not in self._stubs: + self._stubs['update_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', + request_serializer=dataset_service.UpdateDatasetRequest.serialize, + response_deserializer=gca_dataset.Dataset.deserialize, + ) + return self._stubs['update_dataset'] + + @property + def list_datasets(self) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse]]: + r"""Return a callable for the list datasets method over gRPC. + + Lists Datasets in a Location. + + Returns: + Callable[[~.ListDatasetsRequest], + Awaitable[~.ListDatasetsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_datasets' not in self._stubs: + self._stubs['list_datasets'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', + request_serializer=dataset_service.ListDatasetsRequest.serialize, + response_deserializer=dataset_service.ListDatasetsResponse.deserialize, + ) + return self._stubs['list_datasets'] + + @property + def delete_dataset(self) -> Callable[ + [dataset_service.DeleteDatasetRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete dataset method over gRPC. + + Deletes a Dataset. + + Returns: + Callable[[~.DeleteDatasetRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_dataset' not in self._stubs: + self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', + request_serializer=dataset_service.DeleteDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_dataset'] + + @property + def import_data(self) -> Callable[ + [dataset_service.ImportDataRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the import data method over gRPC. + + Imports data into a Dataset. + + Returns: + Callable[[~.ImportDataRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'import_data' not in self._stubs: + self._stubs['import_data'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', + request_serializer=dataset_service.ImportDataRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['import_data'] + + @property + def export_data(self) -> Callable[ + [dataset_service.ExportDataRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the export data method over gRPC. + + Exports data from a Dataset. 
+
+        Returns:
+            Callable[[~.ExportDataRequest],
+                    Awaitable[~.Operation]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'export_data' not in self._stubs:
+            self._stubs['export_data'] = self.grpc_channel.unary_unary(
+                '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData',
+                request_serializer=dataset_service.ExportDataRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs['export_data']
+
+    @property
+    def list_data_items(self) -> Callable[
+            [dataset_service.ListDataItemsRequest],
+            Awaitable[dataset_service.ListDataItemsResponse]]:
+        r"""Return a callable for the list data items method over gRPC.
+
+        Lists DataItems in a Dataset.
+
+        Returns:
+            Callable[[~.ListDataItemsRequest],
+                    Awaitable[~.ListDataItemsResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'list_data_items' not in self._stubs:
+            self._stubs['list_data_items'] = self.grpc_channel.unary_unary(
+                '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems',
+                request_serializer=dataset_service.ListDataItemsRequest.serialize,
+                response_deserializer=dataset_service.ListDataItemsResponse.deserialize,
+            )
+        return self._stubs['list_data_items']
+
+    @property
+    def get_annotation_spec(self) -> Callable[
+            [dataset_service.GetAnnotationSpecRequest],
+            Awaitable[annotation_spec.AnnotationSpec]]:
+        r"""Return a callable for the get annotation spec method over gRPC.
+
+        Gets an AnnotationSpec.
+
+        Returns:
+            Callable[[~.GetAnnotationSpecRequest],
+                    Awaitable[~.AnnotationSpec]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'get_annotation_spec' not in self._stubs:
+            self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary(
+                '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec',
+                request_serializer=dataset_service.GetAnnotationSpecRequest.serialize,
+                response_deserializer=annotation_spec.AnnotationSpec.deserialize,
+            )
+        return self._stubs['get_annotation_spec']
+
+    @property
+    def list_annotations(self) -> Callable[
+            [dataset_service.ListAnnotationsRequest],
+            Awaitable[dataset_service.ListAnnotationsResponse]]:
+        r"""Return a callable for the list annotations method over gRPC.
+
+        Lists Annotations that belong to a DataItem.
+
+        Returns:
+            Callable[[~.ListAnnotationsRequest],
+                    Awaitable[~.ListAnnotationsResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
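+        # Illustrative sketch only (assumes a constructed async transport):
+        # the returned callable is awaitable, e.g.
+        #
+        #     response = await transport.list_annotations(request)
+        #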
+ if 'list_annotations' not in self._stubs: + self._stubs['list_annotations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', + request_serializer=dataset_service.ListAnnotationsRequest.serialize, + response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, + ) + return self._stubs['list_annotations'] + + +__all__ = ( + 'DatasetServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py index af0b93f5a8..e4f3dcfbcf 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py @@ -16,5 +16,9 @@ # from .client import EndpointServiceClient +from .async_client import EndpointServiceAsyncClient -__all__ = ("EndpointServiceClient",) +__all__ = ( + 'EndpointServiceClient', + 'EndpointServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py new file mode 100644 index 0000000000..972cb90855 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -0,0 +1,837 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport +from .client import EndpointServiceClient + + +class EndpointServiceAsyncClient: + """""" + + _client: EndpointServiceClient + + DEFAULT_ENDPOINT = EndpointServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = EndpointServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(EndpointServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(EndpointServiceClient.parse_endpoint_path) + model_path = staticmethod(EndpointServiceClient.model_path) + parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) + + common_billing_account_path = staticmethod(EndpointServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(EndpointServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(EndpointServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(EndpointServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(EndpointServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(EndpointServiceClient.common_project_path) + parse_common_project_path = staticmethod(EndpointServiceClient.parse_common_project_path) + + common_location_path = staticmethod(EndpointServiceClient.common_location_path) + parse_common_location_path = staticmethod(EndpointServiceClient.parse_common_location_path) + + from_service_account_file = EndpointServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> EndpointServiceTransport: + """Return the transport used by the client instance. + + Returns: + EndpointServiceTransport: The transport used by the client instance. 
+ """ + return self._client.transport + + get_transport_class = functools.partial(type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the endpoint service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.EndpointServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = EndpointServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_endpoint(self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates an Endpoint. + + Args: + request (:class:`~.endpoint_service.CreateEndpointRequest`): + The request object. Request message for + ``EndpointService.CreateEndpoint``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Endpoint in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + endpoint (:class:`~.gca_endpoint.Endpoint`): + Required. The Endpoint to create. + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. 
+ + The result type for the operation will be + :class:`~.gca_endpoint.Endpoint`: Models are deployed + into it, and afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, endpoint]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.CreateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_endpoint.Endpoint, + metadata_type=endpoint_service.CreateEndpointOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_endpoint(self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: + r"""Gets an Endpoint. + + Args: + request (:class:`~.endpoint_service.GetEndpointRequest`): + The request object. Request message for + ``EndpointService.GetEndpoint`` + name (:class:`str`): + Required. The name of the Endpoint resource. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.endpoint.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.GetEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
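+        # The `retry` and `timeout` arguments accepted by this method take
+        # precedence over the wrapped defaults; an illustrative call:
+        #
+        #     endpoint = await client.get_endpoint(name=name, timeout=30.0)
+        #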
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_endpoints(self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: + r"""Lists Endpoints in a Location. + + Args: + request (:class:`~.endpoint_service.ListEndpointsRequest`): + The request object. Request message for + ``EndpointService.ListEndpoints``. + parent (:class:`str`): + Required. The resource name of the Location from which + to list the Endpoints. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListEndpointsAsyncPager: + Response message for + ``EndpointService.ListEndpoints``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.ListEndpointsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_endpoints, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListEndpointsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_endpoint(self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: + r"""Updates an Endpoint. 
+ + Args: + request (:class:`~.endpoint_service.UpdateEndpointRequest`): + The request object. Request message for + ``EndpointService.UpdateEndpoint``. + endpoint (:class:`~.gca_endpoint.Endpoint`): + Required. The Endpoint which replaces + the resource on the server. + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to + the resource. + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_endpoint.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([endpoint, update_mask]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.UpdateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint.name', request.endpoint.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_endpoint(self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an Endpoint. + + Args: + request (:class:`~.endpoint_service.DeleteEndpointRequest`): + The request object. Request message for + ``EndpointService.DeleteEndpoint``. + name (:class:`str`): + Required. The name of the Endpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. 
+ + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.DeleteEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def deploy_model(self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deploys a Model into this Endpoint, creating a + DeployedModel within it. + + Args: + request (:class:`~.endpoint_service.DeployModelRequest`): + The request object. Request message for + ``EndpointService.DeployModel``. + endpoint (:class:`str`): + Required. The name of the Endpoint resource into which + to deploy a Model. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (:class:`~.gca_endpoint.DeployedModel`): + Required. The DeployedModel to be created within the + Endpoint. Note that + ``Endpoint.traffic_split`` + must be updated for the DeployedModel to start receiving + traffic, either as part of this call, or via + ``EndpointService.UpdateEndpoint``. + This corresponds to the ``deployed_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + traffic_split (:class:`Sequence[~.endpoint_service.DeployModelRequest.TrafficSplitEntry]`): + A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that + DeployedModel. 
+ + If this field is non-empty, then the Endpoint's + ``traffic_split`` + will be overwritten with it. To refer to the ID of the + just being deployed Model, a "0" should be used, and the + actual ID of the new DeployedModel will be filled in its + place by this method. The traffic percentage values must + add up to 100. + + If this field is empty, then the Endpoint's + ``traffic_split`` + is not updated. + This corresponds to the ``traffic_split`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.endpoint_service.DeployModelResponse`: + Response message for + ``EndpointService.DeployModel``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([endpoint, deployed_model, traffic_split]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.DeployModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + if traffic_split is not None: + request.traffic_split = traffic_split + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.deploy_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + endpoint_service.DeployModelResponse, + metadata_type=endpoint_service.DeployModelOperationMetadata, + ) + + # Done; return the response. + return response + + async def undeploy_model(self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Undeploys a Model from an Endpoint, removing a + DeployedModel from it, and freeing all resources it's + using. + + Args: + request (:class:`~.endpoint_service.UndeployModelRequest`): + The request object. Request message for + ``EndpointService.UndeployModel``. + endpoint (:class:`str`): + Required. The name of the Endpoint resource from which + to undeploy a Model. 
Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model_id (:class:`str`): + Required. The ID of the DeployedModel + to be undeployed from the Endpoint. + This corresponds to the ``deployed_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + traffic_split (:class:`Sequence[~.endpoint_service.UndeployModelRequest.TrafficSplitEntry]`): + If this field is provided, then the Endpoint's + ``traffic_split`` + will be overwritten with it. If last DeployedModel is + being undeployed from the Endpoint, the + [Endpoint.traffic_split] will always end up empty when + this call returns. A DeployedModel will be successfully + undeployed only if it doesn't have any traffic assigned + to it when this method executes, or if this field + unassigns any traffic to it. + This corresponds to the ``traffic_split`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.endpoint_service.UndeployModelResponse`: + Response message for + ``EndpointService.UndeployModel``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([endpoint, deployed_model_id, traffic_split]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = endpoint_service.UndeployModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + if traffic_split is not None: + request.traffic_split = traffic_split + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.undeploy_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + endpoint_service.UndeployModelResponse, + metadata_type=endpoint_service.UndeployModelOperationMetadata, + ) + + # Done; return the response. 
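+        # Illustrative usage sketch (hypothetical resource names, not
+        # generated code); the returned AsyncOperation resolves to an
+        # endpoint_service.UndeployModelResponse:
+        #
+        #     client = EndpointServiceAsyncClient()
+        #     op = await client.undeploy_model(
+        #         endpoint="projects/my-project/locations/us-central1/endpoints/123",
+        #         deployed_model_id="456",
+        #     )
+        #     result = await op.result()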
+ return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'EndpointServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 0ed3efe87a..f601c6f145 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint @@ -36,8 +43,9 @@ from google.protobuf import field_mask_pb2 as field_mask # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import EndpointServiceTransport +from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import EndpointServiceGrpcTransport +from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport class EndpointServiceClientMeta(type): @@ -47,13 +55,13 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry['grpc'] = EndpointServiceGrpcTransport + _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry["grpc"] = EndpointServiceGrpcTransport - - def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. 
 Args:
@@ -75,8 +83,38 @@ def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTranspor
 class EndpointServiceClient(metaclass=EndpointServiceClientMeta):
     """"""
 
-    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
-        api_endpoint="aiplatform.googleapis.com"
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -93,26 +131,105 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(filename)
-        kwargs["credentials"] = credentials
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs['credentials'] = credentials
         return cls(*args, **kwargs)
 
     from_service_account_json = from_service_account_file
 
+    @property
+    def transport(self) -> EndpointServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            EndpointServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
+
     @staticmethod
-    def endpoint_path(project: str, location: str, endpoint: str,) -> str:
+    def endpoint_path(project: str,location: str,endpoint: str,) -> str:
         """Return a fully-qualified endpoint string."""
-        return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
-            project=project, location=location, endpoint=endpoint,
-        )
+        return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, )
+
+    @staticmethod
+    def parse_endpoint_path(path: str) -> Dict[str,str]:
+        """Parse an endpoint path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str,location: str,model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str,str]:
+        """Parse a model path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$", path)
+        return m.groupdict() if m else {}
 
-    def __init__(
-        self,
-        *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, EndpointServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
-    ) -> None:
+    @staticmethod
+    def common_billing_account_path(billing_account: str, ) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str,str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str, ) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder, )
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str,str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str, ) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization, )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str,str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str, ) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project, )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str,str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str, ) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str,str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(self, *,
+            credentials: Optional[credentials.Credentials] = None,
+            transport: Union[str, EndpointServiceTransport, None] = None,
+            client_options: Optional[client_options_lib.ClientOptions] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
         """Instantiate the endpoint service client.
 
         Args:
@@ -124,38 +241,107 @@ def __init__(
             transport (Union[str, ~.EndpointServiceTransport]): The
                 transport to use. If set to None, a transport is chosen
                 automatically.
-            client_options (ClientOptions): Custom options for the client.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
         """
         if isinstance(client_options, dict):
-            client_options = ClientOptions.from_dict(client_options)
+            client_options = client_options_lib.from_dict(client_options)
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+
+        # Create SSL credentials for mutual TLS if needed.
+        use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))
+
+        ssl_credentials = None
+        is_mtls = False
+        if use_client_cert:
+            if client_options.client_cert_source:
+                import grpc  # type: ignore
+
+                cert, key = client_options.client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+                is_mtls = True
+            else:
+                creds = SslCredentials()
+                is_mtls = creds.is_mtls
+                ssl_credentials = creds.ssl_credentials if is_mtls else None
+
+        # Figure out which api endpoint to use.
+        if client_options.api_endpoint is not None:
+            api_endpoint = client_options.api_endpoint
+        else:
+            use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+            if use_mtls_env == "never":
+                api_endpoint = self.DEFAULT_ENDPOINT
+            elif use_mtls_env == "always":
+                api_endpoint = self.DEFAULT_MTLS_ENDPOINT
+            elif use_mtls_env == "auto":
+                api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
+            else:
+                raise MutualTLSChannelError(
                    "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always"
                )
 
         # Save or instantiate the transport.
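+        # (Illustrative, assumed values: with no explicit api_endpoint,
+        # "never" and a cert-less "auto" resolve to
+        # "aiplatform.googleapis.com", while "always" and a cert-bearing
+        # "auto" resolve to DEFAULT_MTLS_ENDPOINT, i.e.
+        # _get_default_mtls_endpoint("aiplatform.googleapis.com")
+        # == "aiplatform.mtls.googleapis.com". The transport instantiated
+        # below is handed this host.)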
# Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, EndpointServiceTransport): - if credentials: + # transport is a EndpointServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: raise ValueError( "When providing a transport instance, " - "provide its credentials directly." + "provide its scopes directly." ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) - def create_endpoint( - self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_endpoint(self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates an Endpoint. Args: @@ -194,32 +380,45 @@ def create_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, endpoint]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.CreateEndpointRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if endpoint is not None: - request.endpoint = endpoint + has_flattened_params = any([parent, endpoint]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.CreateEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.CreateEndpointRequest): + request = endpoint_service.CreateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_endpoint, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. 
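+        # (Illustrative: to_grpc_metadata emits the standard
+        # ("x-goog-request-params", "parent=<url-encoded parent>") pair so
+        # the backend can route the request by resource location.)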
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -232,15 +431,14 @@ def create_endpoint( # Done; return the response. return response - def get_endpoint( - self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint(self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -270,49 +468,55 @@ def get_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = endpoint_service.GetEndpointRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.GetEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.GetEndpointRequest): + request = endpoint_service.GetEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_endpoint, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_endpoint] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. 
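+        # Illustrative usage sketch (hypothetical resource name, not
+        # generated code):
+        #
+        #     client = EndpointServiceClient()
+        #     ep = client.get_endpoint(
+        #         name="projects/my-project/locations/us-central1/endpoints/123")
+        #     print(ep.display_name)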
return response - def list_endpoints( - self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints(self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -345,56 +549,65 @@ def list_endpoints( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = endpoint_service.ListEndpointsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.ListEndpointsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.ListEndpointsRequest): + request = endpoint_service.ListEndpointsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_endpoints, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_endpoints] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. 
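+        # Illustrative usage sketch (hypothetical parent, not generated
+        # code); the pager fetches follow-up pages lazily while iterating:
+        #
+        #     pager = client.list_endpoints(
+        #         parent="projects/my-project/locations/us-central1")
+        #     for endpoint in pager:
+        #         print(endpoint.name)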
return response - def update_endpoint( - self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint(self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -430,45 +643,57 @@ def update_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.UpdateEndpointRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if update_mask is not None: - request.update_mask = update_mask + has_flattened_params = any([endpoint, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.UpdateEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.UpdateEndpointRequest): + request = endpoint_service.UpdateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_endpoint, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint.name', request.endpoint.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_endpoint( - self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_endpoint(self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes an Endpoint. Args: @@ -513,30 +738,43 @@ def delete_endpoint( # Create or coerce a protobuf request object. 
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = endpoint_service.DeleteEndpointRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.DeleteEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.DeleteEndpointRequest): + request = endpoint_service.DeleteEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_endpoint, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -549,19 +787,16 @@ def delete_endpoint( # Done; return the response. return response - def deploy_model( - self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[ - endpoint_service.DeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def deploy_model(self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -625,34 +860,48 @@ def deploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model, traffic_split]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) + has_flattened_params = any([endpoint, deployed_model, traffic_split]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.DeployModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.DeployModelRequest): + request = endpoint_service.DeployModelRequest(request) - request = endpoint_service.DeployModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. - # If we have keyword arguments corresponding to fields on the - # request, apply these. + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model - if endpoint is not None: - request.endpoint = endpoint - if deployed_model is not None: - request.deployed_model = deployed_model - if traffic_split is not None: - request.traffic_split = traffic_split + if traffic_split: + request.traffic_split.extend(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.deploy_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.deploy_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -665,19 +914,16 @@ def deploy_model( # Done; return the response. return response - def undeploy_model( - self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[ - endpoint_service.UndeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def undeploy_model(self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -732,34 +978,48 @@ def undeploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model_id, traffic_split]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) + has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = endpoint_service.UndeployModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.UndeployModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.UndeployModelRequest): + request = endpoint_service.UndeployModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id - if traffic_split is not None: - request.traffic_split = traffic_split + if endpoint is not None: + request.endpoint = endpoint + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + if traffic_split: + request.traffic_split.extend(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.undeploy_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.undeploy_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -773,14 +1033,21 @@ def undeploy_model( return response + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("EndpointServiceClient",) +__all__ = ( + 'EndpointServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 4c797e56fd..50399b1826 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -38,16 +38,12 @@ class ListEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - - def __init__( - self, - method: Callable[ - [endpoint_service.ListEndpointsRequest], - endpoint_service.ListEndpointsResponse, - ], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - ): + def __init__(self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -57,10 +53,13 @@ def __init__( The initial request object. response (:class:`~.endpoint_service.ListEndpointsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -70,7 +69,7 @@ def pages(self) -> Iterable[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[endpoint.Endpoint]: @@ -78,4 +77,67 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: yield from page.endpoints def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListEndpointsAsyncPager: + """A pager for iterating through ``list_endpoints`` requests. + + This class thinly wraps an initial + :class:`~.endpoint_service.ListEndpointsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``endpoints`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListEndpoints`` requests and continue to iterate + through the ``endpoints`` field on the + corresponding responses. + + All the usual :class:`~.endpoint_service.ListEndpointsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.endpoint_service.ListEndpointsRequest`): + The initial request object. + response (:class:`~.endpoint_service.ListEndpointsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = endpoint_service.ListEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[endpoint_service.ListEndpointsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[endpoint.Endpoint]: + async def async_generator(): + async for page in self.pages: + for response in page.endpoints: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index 62eff450a6..fea1a635d6 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import EndpointServiceTransport from .grpc import EndpointServiceGrpcTransport +from .grpc_asyncio import EndpointServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] -_transport_registry["grpc"] = EndpointServiceGrpcTransport +_transport_registry['grpc'] = EndpointServiceGrpcTransport +_transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport __all__ = ( - "EndpointServiceTransport", - "EndpointServiceGrpcTransport", + 'EndpointServiceTransport', + 'EndpointServiceGrpcTransport', + 'EndpointServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index 43baa080e0..cb5e891416 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -28,17 +32,32 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class EndpointServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: 
credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -48,74 +67,154 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host # If no credentials are provided, then determine the appropriate # defaults. - if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods.
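+ # Each callable below passes through gapic_v1.method.wrap_method, which
+ # layers the default retry/timeout behaviour and attaches the user-agent
+ # string derived from ``client_info`` onto the raw transport method.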
+ self._wrapped_methods = { + self.create_endpoint: gapic_v1.method.wrap_method( + self.create_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.get_endpoint: gapic_v1.method.wrap_method( + self.get_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.list_endpoints: gapic_v1.method.wrap_method( + self.list_endpoints, + default_timeout=None, + client_info=client_info, + ), + self.update_endpoint: gapic_v1.method.wrap_method( + self.update_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.delete_endpoint: gapic_v1.method.wrap_method( + self.delete_endpoint, + default_timeout=None, + client_info=client_info, + ), + self.deploy_model: gapic_v1.method.wrap_method( + self.deploy_model, + default_timeout=None, + client_info=client_info, + ), + self.undeploy_model: gapic_v1.method.wrap_method( + self.undeploy_model, + default_timeout=None, + client_info=client_info, + ), + + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property - def create_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], operations.Operation - ]: - raise NotImplementedError + def create_endpoint(self) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def get_endpoint( - self, - ) -> typing.Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: - raise NotImplementedError + def get_endpoint(self) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[ + endpoint.Endpoint, + typing.Awaitable[endpoint.Endpoint] + ]]: + raise NotImplementedError() @property - def list_endpoints( - self, - ) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse - ]: - raise NotImplementedError + def list_endpoints(self) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse] + ]]: + raise NotImplementedError() @property - def update_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint - ]: - raise NotImplementedError + def update_endpoint(self) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[ + gca_endpoint.Endpoint, + typing.Awaitable[gca_endpoint.Endpoint] + ]]: + raise NotImplementedError() @property - def delete_endpoint( - self, - ) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], operations.Operation - ]: - raise NotImplementedError + def delete_endpoint(self) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def deploy_model( - self, - ) -> typing.Callable[[endpoint_service.DeployModelRequest], operations.Operation]: - raise NotImplementedError + def deploy_model(self) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def undeploy_model( - self, - ) -> typing.Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: - raise NotImplementedError - - 
-__all__ = ("EndpointServiceTransport",) + def undeploy_model(self) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'EndpointServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 9bde20f31f..fbbf33b2b7 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -28,7 +32,7 @@ from google.cloud.aiplatform_v1beta1.types import endpoint_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import EndpointServiceTransport +from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO class EndpointServiceGrpcTransport(EndpointServiceTransport): @@ -41,14 +45,20 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - channel: grpc.Channel = None - ) -> None: + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -59,29 +69,107 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A list of scopes. This argument is + ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated.
A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ - # Sanity check: Ensure that channel and credentials are not both - # provided. if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. credentials = False - # Run the base constructor. - super().__init__(host=host, credentials=credentials) + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._stubs = {} # type: Dict[str, Callable] - # If a channel was explicitly provided, set it. - if channel: - self._grpc_channel = channel + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - **kwargs - ) -> grpc.Channel: + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: """Create and return a gRPC channel object. 
Args: address (Optional[str]): The host for the channel to use. @@ -90,30 +178,37 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property @@ -124,18 +219,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self.__dict__['operations_client'] @property - def create_endpoint( - self, - ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]: + def create_endpoint(self) -> Callable[ + [endpoint_service.CreateEndpointRequest], + operations.Operation]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -150,18 +245,18 @@ def create_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each.
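+ # A minimal usage sketch (hypothetical values; assumes default
+ # credentials are available in the environment):
+ #
+ #   transport = EndpointServiceGrpcTransport()
+ #   request = endpoint_service.CreateEndpointRequest(
+ #       parent='projects/my-project/locations/us-central1',
+ #       endpoint=gca_endpoint.Endpoint(display_name='my-endpoint'))
+ #   operation = transport.create_endpoint(request)  # raw longrunning Operation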
- if "create_endpoint" not in self._stubs: - self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", + if 'create_endpoint' not in self._stubs: + self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_endpoint"] + return self._stubs['create_endpoint'] @property - def get_endpoint( - self, - ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: + def get_endpoint(self) -> Callable[ + [endpoint_service.GetEndpointRequest], + endpoint.Endpoint]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -176,20 +271,18 @@ def get_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_endpoint" not in self._stubs: - self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", + if 'get_endpoint' not in self._stubs: + self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs["get_endpoint"] + return self._stubs['get_endpoint'] @property - def list_endpoints( - self, - ) -> Callable[ - [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse - ]: + def list_endpoints(self) -> Callable[ + [endpoint_service.ListEndpointsRequest], + endpoint_service.ListEndpointsResponse]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -204,18 +297,18 @@ def list_endpoints( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_endpoints" not in self._stubs: - self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", + if 'list_endpoints' not in self._stubs: + self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs["list_endpoints"] + return self._stubs['list_endpoints'] @property - def update_endpoint( - self, - ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: + def update_endpoint(self) -> Callable[ + [endpoint_service.UpdateEndpointRequest], + gca_endpoint.Endpoint]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -230,18 +323,18 @@ def update_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
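+ # Note that UpdateEndpoint is not a long-running operation: the stub
+ # below deserializes the mutated resource directly via
+ # gca_endpoint.Endpoint.deserialize rather than operations.Operation.FromString.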
- if "update_endpoint" not in self._stubs: - self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", + if 'update_endpoint' not in self._stubs: + self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs["update_endpoint"] + return self._stubs['update_endpoint'] @property - def delete_endpoint( - self, - ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: + def delete_endpoint(self) -> Callable[ + [endpoint_service.DeleteEndpointRequest], + operations.Operation]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -256,18 +349,18 @@ def delete_endpoint( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_endpoint" not in self._stubs: - self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", + if 'delete_endpoint' not in self._stubs: + self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_endpoint"] + return self._stubs['delete_endpoint'] @property - def deploy_model( - self, - ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: + def deploy_model(self) -> Callable[ + [endpoint_service.DeployModelRequest], + operations.Operation]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -283,18 +376,18 @@ def deploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "deploy_model" not in self._stubs: - self._stubs["deploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", + if 'deploy_model' not in self._stubs: + self._stubs['deploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel', request_serializer=endpoint_service.DeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["deploy_model"] + return self._stubs['deploy_model'] @property - def undeploy_model( - self, - ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: + def undeploy_model(self) -> Callable[ + [endpoint_service.UndeployModelRequest], + operations.Operation]: r"""Return a callable for the undeploy model method over gRPC. Undeploys a Model from an Endpoint, removing a @@ -311,13 +404,15 @@ def undeploy_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
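+ # DeployModel and UndeployModel both surface as long-running operations;
+ # callers poll them through the ``operations_client`` property above.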
- if "undeploy_model" not in self._stubs: - self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", + if 'undeploy_model' not in self._stubs: + self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel', request_serializer=endpoint_service.UndeployModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["undeploy_model"] + return self._stubs['undeploy_model'] -__all__ = ("EndpointServiceGrpcTransport",) +__all__ = ( + 'EndpointServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..69d7842201 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import EndpointServiceGrpcTransport + + +class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): + """gRPC AsyncIO backend transport for EndpointService. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + + @property + def create_endpoint(self) -> Callable[ + [endpoint_service.CreateEndpointRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create endpoint method over gRPC. + + Creates an Endpoint. + + Returns: + Callable[[~.CreateEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
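+ # On the aio.Channel this unary_unary callable returns an awaitable;
+ # async clients await it and wrap long-running results with
+ # operation_async.from_gapic (see the async client methods below).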
+ if 'create_endpoint' not in self._stubs: + self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', + request_serializer=endpoint_service.CreateEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_endpoint'] + + @property + def get_endpoint(self) -> Callable[ + [endpoint_service.GetEndpointRequest], + Awaitable[endpoint.Endpoint]]: + r"""Return a callable for the get endpoint method over gRPC. + + Gets an Endpoint. + + Returns: + Callable[[~.GetEndpointRequest], + Awaitable[~.Endpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_endpoint' not in self._stubs: + self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', + request_serializer=endpoint_service.GetEndpointRequest.serialize, + response_deserializer=endpoint.Endpoint.deserialize, + ) + return self._stubs['get_endpoint'] + + @property + def list_endpoints(self) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse]]: + r"""Return a callable for the list endpoints method over gRPC. + + Lists Endpoints in a Location. + + Returns: + Callable[[~.ListEndpointsRequest], + Awaitable[~.ListEndpointsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_endpoints' not in self._stubs: + self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', + request_serializer=endpoint_service.ListEndpointsRequest.serialize, + response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, + ) + return self._stubs['list_endpoints'] + + @property + def update_endpoint(self) -> Callable[ + [endpoint_service.UpdateEndpointRequest], + Awaitable[gca_endpoint.Endpoint]]: + r"""Return a callable for the update endpoint method over gRPC. + + Updates an Endpoint. + + Returns: + Callable[[~.UpdateEndpointRequest], + Awaitable[~.Endpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_endpoint' not in self._stubs: + self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', + request_serializer=endpoint_service.UpdateEndpointRequest.serialize, + response_deserializer=gca_endpoint.Endpoint.deserialize, + ) + return self._stubs['update_endpoint'] + + @property + def delete_endpoint(self) -> Callable[ + [endpoint_service.DeleteEndpointRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete endpoint method over gRPC. + + Deletes an Endpoint. + + Returns: + Callable[[~.DeleteEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_endpoint' not in self._stubs: + self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', + request_serializer=endpoint_service.DeleteEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_endpoint'] + + @property + def deploy_model(self) -> Callable[ + [endpoint_service.DeployModelRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the deploy model method over gRPC. + + Deploys a Model into this Endpoint, creating a + DeployedModel within it. + + Returns: + Callable[[~.DeployModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'deploy_model' not in self._stubs: + self._stubs['deploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel', + request_serializer=endpoint_service.DeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['deploy_model'] + + @property + def undeploy_model(self) -> Callable[ + [endpoint_service.UndeployModelRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the undeploy model method over gRPC. + + Undeploys a Model from an Endpoint, removing a + DeployedModel from it, and freeing all resources it's + using. + + Returns: + Callable[[~.UndeployModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
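+ # As with deploy_model above, the awaitable resolves to a raw
+ # longrunning operations.Operation message rather than a typed result.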
+ if 'undeploy_model' not in self._stubs: + self._stubs['undeploy_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel', + request_serializer=endpoint_service.UndeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['undeploy_model'] + + +__all__ = ( + 'EndpointServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py index bf1248d281..037407b714 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py @@ -16,5 +16,9 @@ # from .client import JobServiceClient +from .async_client import JobServiceAsyncClient -__all__ = ("JobServiceClient",) +__all__ = ( + 'JobServiceClient', + 'JobServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py new file mode 100644 index 0000000000..d309df53a5 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -0,0 +1,1905 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.job_service import pagers +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import completion_stats +from google.cloud.aiplatform_v1beta1.types import custom_job +from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import study +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + +from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport +from .client import JobServiceClient + + +class JobServiceAsyncClient: + """A service for creating and managing AI Platform's jobs.""" + + _client: JobServiceClient + + DEFAULT_ENDPOINT = JobServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT + + batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) + parse_batch_prediction_job_path = staticmethod(JobServiceClient.parse_batch_prediction_job_path) + custom_job_path = staticmethod(JobServiceClient.custom_job_path) + parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) + data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) + parse_data_labeling_job_path = staticmethod(JobServiceClient.parse_data_labeling_job_path) + dataset_path = staticmethod(JobServiceClient.dataset_path) + parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) + hyperparameter_tuning_job_path = staticmethod(JobServiceClient.hyperparameter_tuning_job_path) + parse_hyperparameter_tuning_job_path = staticmethod(JobServiceClient.parse_hyperparameter_tuning_job_path) + model_path = staticmethod(JobServiceClient.model_path) + parse_model_path = 
staticmethod(JobServiceClient.parse_model_path) + + common_billing_account_path = staticmethod(JobServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(JobServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(JobServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(JobServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(JobServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(JobServiceClient.common_project_path) + parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) + + common_location_path = staticmethod(JobServiceClient.common_location_path) + parse_common_location_path = staticmethod(JobServiceClient.parse_common_location_path) + + from_service_account_file = JobServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> JobServiceTransport: + """Return the transport used by the client instance. + + Returns: + JobServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(JobServiceClient).get_transport_class, type(JobServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the job service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.JobServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + + self._client = JobServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_custom_job(self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: + r"""Creates a CustomJob. A created CustomJob right away + will be attempted to be run. + + Args: + request (:class:`~.job_service.CreateCustomJobRequest`): + The request object. Request message for + ``JobService.CreateCustomJob``. + parent (:class:`str`): + Required. The resource name of the Location to create + the CustomJob in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + custom_job (:class:`~.gca_custom_job.CustomJob`): + Required. The CustomJob to create. + This corresponds to the ``custom_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_custom_job.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, custom_job]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CreateCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_custom_job(self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: + r"""Gets a CustomJob. + + Args: + request (:class:`~.job_service.GetCustomJobRequest`): + The request object. Request message for + ``JobService.GetCustomJob``. + name (:class:`str`): + Required. 
The name of the CustomJob resource. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.custom_job.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.GetCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_custom_jobs(self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: + r"""Lists CustomJobs in a Location. + + Args: + request (:class:`~.job_service.ListCustomJobsRequest`): + The request object. Request message for + ``JobService.ListCustomJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + CustomJobs from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListCustomJobsAsyncPager: + Response message for + ``JobService.ListCustomJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
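+ # Callers supply either a fully-formed request proto or flattened
+ # fields such as ``parent``, never both; mixing them raises the
+ # ValueError below.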
+ if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.ListCustomJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_custom_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListCustomJobsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_custom_job(self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a CustomJob. + + Args: + request (:class:`~.job_service.DeleteCustomJobRequest`): + The request object. Request message for + ``JobService.DeleteCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.DeleteCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
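+ # The operation future constructed further down resolves to empty.Empty,
+ # with DeleteOperationMetadata exposed while the deletion is pending.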
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_custom_job(self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`~.job_service.CancelCustomJobRequest`): + The request object. Request message for + ``JobService.CancelCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob to cancel. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CancelCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. 
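+ # CancelCustomJob has no response payload (google.protobuf.Empty), so
+ # the RPC is awaited only so that errors propagate to the caller.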
+ await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_data_labeling_job(self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: + r"""Creates a DataLabelingJob. + + Args: + request (:class:`~.job_service.CreateDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.CreateDataLabelingJob][]. + parent (:class:`str`): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + data_labeling_job (:class:`~.gca_data_labeling_job.DataLabelingJob`): + Required. The DataLabelingJob to + create. + This corresponds to the ``data_labeling_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_data_labeling_job.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, data_labeling_job]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CreateDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_data_labeling_job(self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: + r"""Gets a DataLabelingJob. + + Args: + request (:class:`~.job_service.GetDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.GetDataLabelingJob][]. + name (:class:`str`): + Required. The name of the DataLabelingJob. 
Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.data_labeling_job.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.GetDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_data_labeling_jobs(self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: + r"""Lists DataLabelingJobs in a Location. + + Args: + request (:class:`~.job_service.ListDataLabelingJobsRequest`): + The request object. Request message for + [DataLabelingJobService.ListDataLabelingJobs][]. + parent (:class:`str`): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListDataLabelingJobsAsyncPager: + Response message for + ``JobService.ListDataLabelingJobs``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.ListDataLabelingJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
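+        # Illustrative usage (a sketch with hypothetical values, not
+        # generated code): `request` and the flattened `parent=` argument
+        # are mutually exclusive; because proto-plus messages accept
+        # mappings, a caller may pass either
+        #
+        #     await client.list_data_labeling_jobs(
+        #         request={'parent': parent, 'page_size': 50})
+        #
+        # or `parent=parent`, but never both.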
+ + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_data_labeling_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDataLabelingJobsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_data_labeling_job(self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a DataLabelingJob. + + Args: + request (:class:`~.job_service.DeleteDataLabelingJobRequest`): + The request object. Request message for + ``JobService.DeleteDataLabelingJob``. + name (:class:`str`): + Required. The name of the DataLabelingJob to be deleted. + Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.DeleteDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. 
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_data_labeling_job(self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Args: + request (:class:`~.job_service.CancelDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.CancelDataLabelingJob][]. + name (:class:`str`): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CancelDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_hyperparameter_tuning_job(self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Creates a HyperparameterTuningJob + + Args: + request (:class:`~.job_service.CreateHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.CreateHyperparameterTuningJob``. + parent (:class:`str`): + Required. The resource name of the Location to create + the HyperparameterTuningJob in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ hyperparameter_tuning_job (:class:`~.gca_hyperparameter_tuning_job.HyperparameterTuningJob`): + Required. The HyperparameterTuningJob + to create. + This corresponds to the ``hyperparameter_tuning_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_hyperparameter_tuning_job.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, hyperparameter_tuning_job]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CreateHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_hyperparameter_tuning_job(self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Gets a HyperparameterTuningJob + + Args: + request (:class:`~.job_service.GetHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.GetHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob + resource. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.hyperparameter_tuning_job.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. 
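+        # Illustrative note (an assumption-flagged sketch): `request` may be
+        # a `GetHyperparameterTuningJobRequest` proto or, since proto-plus
+        # messages accept mappings, a plain dict such as
+        # `{'name': tuning_job_name}` with a hypothetical resource name.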
+ # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.GetHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_hyperparameter_tuning_jobs(self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + r"""Lists HyperparameterTuningJobs in a Location. + + Args: + request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + The request object. Request message for + ``JobService.ListHyperparameterTuningJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + HyperparameterTuningJobs from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListHyperparameterTuningJobsAsyncPager: + Response message for + ``JobService.ListHyperparameterTuningJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.ListHyperparameterTuningJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. 
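+        # Illustrative usage (hypothetical names; a sketch, not generated
+        # code): the pager built below implements `__aiter__`, so results
+        # can be consumed with:
+        #
+        #     pager = await client.list_hyperparameter_tuning_jobs(
+        #         parent=parent)
+        #     async for job in pager:
+        #         print(job.display_name)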
+ response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListHyperparameterTuningJobsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_hyperparameter_tuning_job(self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a HyperparameterTuningJob. + + Args: + request (:class:`~.job_service.DeleteHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.DeleteHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob + resource to be deleted. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.DeleteHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. 
+ return response + + async def cancel_hyperparameter_tuning_job(self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a HyperparameterTuningJob. Starts asynchronous + cancellation on the HyperparameterTuningJob. The server makes a + best effort to cancel the job, but success is not guaranteed. + Clients can use + ``JobService.GetHyperparameterTuningJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the HyperparameterTuningJob is not deleted; + instead it becomes a job with a + ``HyperparameterTuningJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``HyperparameterTuningJob.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`~.job_service.CancelHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.CancelHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob to + cancel. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.CancelHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def create_batch_prediction_job(self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: + r"""Creates a BatchPredictionJob. A BatchPredictionJob + once created will right away be attempted to start. + + Args: + request (:class:`~.job_service.CreateBatchPredictionJobRequest`): + The request object. Request message for + ``JobService.CreateBatchPredictionJob``. + parent (:class:`str`): + Required. 
The resource name of the Location to create
+                the BatchPredictionJob in. Format:
+                ``projects/{project}/locations/{location}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            batch_prediction_job (:class:`~.gca_batch_prediction_job.BatchPredictionJob`):
+                Required. The BatchPredictionJob to
+                create.
+                This corresponds to the ``batch_prediction_job`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.gca_batch_prediction_job.BatchPredictionJob:
+                A job that uses a
+                ``Model``
+                to produce predictions on multiple [input
+                instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config].
+                If predictions for a significant portion of the instances
+                fail, the job may finish without attempting predictions
+                for all remaining instances.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([parent, batch_prediction_job]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = job_service.CreateBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+        if batch_prediction_job is not None:
+            request.batch_prediction_job = batch_prediction_job
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.create_batch_prediction_job,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', request.parent),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def get_batch_prediction_job(self,
+            request: job_service.GetBatchPredictionJobRequest = None,
+            *,
+            name: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> batch_prediction_job.BatchPredictionJob:
+        r"""Gets a BatchPredictionJob
+
+        Args:
+            request (:class:`~.job_service.GetBatchPredictionJobRequest`):
+                The request object. Request message for
+                ``JobService.GetBatchPredictionJob``.
+            name (:class:`str`):
+                Required. The name of the BatchPredictionJob resource.
+                Format:
+
+                ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.batch_prediction_job.BatchPredictionJob:
+                A job that uses a
+                ``Model``
+                to produce predictions on multiple [input
+                instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config].
+                If predictions for a significant portion of the instances
+                fail, the job may finish without attempting predictions
+                for all remaining instances.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([name]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = job_service.GetBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.get_batch_prediction_job,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('name', request.name),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def list_batch_prediction_jobs(self,
+            request: job_service.ListBatchPredictionJobsRequest = None,
+            *,
+            parent: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> pagers.ListBatchPredictionJobsAsyncPager:
+        r"""Lists BatchPredictionJobs in a Location.
+
+        Args:
+            request (:class:`~.job_service.ListBatchPredictionJobsRequest`):
+                The request object. Request message for
+                ``JobService.ListBatchPredictionJobs``.
+            parent (:class:`str`):
+                Required. The resource name of the Location to list the
+                BatchPredictionJobs from. Format:
+                ``projects/{project}/locations/{location}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.pagers.ListBatchPredictionJobsAsyncPager:
+                Response message for
+                ``JobService.ListBatchPredictionJobs``
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([parent]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = job_service.ListBatchPredictionJobsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
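+        # Illustrative note (a sketch under stated assumptions):
+        # `default_timeout=None` below applies no client-side deadline, so a
+        # caller wanting one passes it per call, e.g. a hypothetical
+        # `timeout=30.0`.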
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_batch_prediction_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListBatchPredictionJobsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_batch_prediction_job(self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a BatchPredictionJob. Can only be called on + jobs that already finished. + + Args: + request (:class:`~.job_service.DeleteBatchPredictionJobRequest`): + The request object. Request message for + ``JobService.DeleteBatchPredictionJob``. + name (:class:`str`): + Required. The name of the BatchPredictionJob resource to + be deleted. Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = job_service.DeleteBatchPredictionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_batch_prediction_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. 
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            empty.Empty,
+            metadata_type=gca_operation.DeleteOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def cancel_batch_prediction_job(self,
+            request: job_service.CancelBatchPredictionJobRequest = None,
+            *,
+            name: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> None:
+        r"""Cancels a BatchPredictionJob.
+
+        Starts asynchronous cancellation on the BatchPredictionJob. The
+        server makes the best effort to cancel the job, but success is
+        not guaranteed. Clients can use
+        ``JobService.GetBatchPredictionJob``
+        or other methods to check whether the cancellation succeeded or
+        whether the job completed despite cancellation. On a successful
+        cancellation, the BatchPredictionJob is not deleted; instead its
+        ``BatchPredictionJob.state``
+        is set to ``CANCELLED``. Any files already outputted by the job
+        are not deleted.
+
+        Args:
+            request (:class:`~.job_service.CancelBatchPredictionJobRequest`):
+                The request object. Request message for
+                ``JobService.CancelBatchPredictionJob``.
+            name (:class:`str`):
+                Required. The name of the BatchPredictionJob to cancel.
+                Format:
+
+                ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([name]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = job_service.CancelBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.cancel_batch_prediction_job,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('name', request.name),
+            )),
+        )
+
+        # Send the request.
+ await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'JobServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index b56a9a7871..ca78580819 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -16,33 +16,34 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources @@ -55,8 +56,9 @@ from google.rpc import status_pb2 as status # type: ignore from google.type import money_pb2 as money # type: ignore -from .transports.base import 
JobServiceTransport
+from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO
 from .transports.grpc import JobServiceGrpcTransport
+from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport
 
 
 class JobServiceClientMeta(type):
@@ -66,11 +68,13 @@ class JobServiceClientMeta(type):
     support objects (e.g. transport) without polluting the client instance
     objects.
     """
-    _transport_registry = OrderedDict()  # type: Dict[str, Type[JobServiceTransport]]
-    _transport_registry["grpc"] = JobServiceGrpcTransport
+    _transport_registry['grpc'] = JobServiceGrpcTransport
+    _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport
 
-    def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]:
+    def get_transport_class(cls,
+            label: str = None,
+        ) -> Type[JobServiceTransport]:
         """Return an appropriate transport class.
 
         Args:
@@ -92,8 +96,38 @@ def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]:
 class JobServiceClient(metaclass=JobServiceClientMeta):
     """A service for creating and managing AI Platform's jobs."""
 
-    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
-        api_endpoint="aiplatform.googleapis.com"
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -110,57 +144,149 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(filename)
-        kwargs["credentials"] = credentials
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs['credentials'] = credentials
        return cls(*args, **kwargs)
 
     from_service_account_json = from_service_account_file
 
+    @property
+    def transport(self) -> JobServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            JobServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
+
+    @staticmethod
+    def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str:
+        """Return a fully-qualified batch_prediction_job string."""
+        return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, )
+
+    @staticmethod
+    def parse_batch_prediction_job_path(path: str) -> Dict[str,str]:
+        """Parse a batch_prediction_job path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/batchPredictionJobs/(?P<batch_prediction_job>.+?)$", path)
+        return m.groupdict() if m else {}
+
     @staticmethod
-    def custom_job_path(project: str, location: str, custom_job: str,) -> str:
+    def custom_job_path(project: str,location: str,custom_job: str,) -> str:
         """Return a fully-qualified custom_job string."""
-        return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(
-            project=project, location=location, custom_job=custom_job,
-        )
+        return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, )
 
     @staticmethod
-    def batch_prediction_job_path(
-        project: str, location: str, batch_prediction_job: str,
-    ) -> str:
-        """Return a fully-qualified batch_prediction_job string."""
-        return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(
-            project=project,
-            location=location,
-            batch_prediction_job=batch_prediction_job,
-        )
+    def parse_custom_job_path(path: str) -> Dict[str,str]:
+        """Parse a custom_job path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/customJobs/(?P<custom_job>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str:
+        """Return a fully-qualified data_labeling_job string."""
+        return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, )
+
+    @staticmethod
+    def parse_data_labeling_job_path(path: str) -> Dict[str,str]:
+        """Parse a data_labeling_job path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/dataLabelingJobs/(?P<data_labeling_job>.+?)$", path)
+        return m.groupdict() if m else {}
 
     @staticmethod
-    def hyperparameter_tuning_job_path(
-        project: str, location: str, hyperparameter_tuning_job: str,
-    ) -> str:
+    def dataset_path(project: str,location: str,dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str,str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str:
         """Return a fully-qualified hyperparameter_tuning_job string."""
-        return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(
-            project=project,
-            location=location,
-            hyperparameter_tuning_job=hyperparameter_tuning_job,
-        )
+        return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, )
 
     @staticmethod
-    def data_labeling_job_path(
-        project: str, location: str, data_labeling_job: str,
-    ) -> str:
-        """Return a fully-qualified data_labeling_job string."""
-        return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(
-            project=project, location=location, data_labeling_job=data_labeling_job,
-        )
+    def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str,str]:
+        """Parse a hyperparameter_tuning_job path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/hyperparameterTuningJobs/(?P<hyperparameter_tuning_job>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str,location: str,model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str,str]:
+        """Parse a model path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$", path)
+        return m.groupdict() if m else {}
 
-    def __init__(
-        self,
-        *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, JobServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
-    ) -> None:
+    @staticmethod
+    def common_billing_account_path(billing_account: str, ) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str,str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str, ) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder, )
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str,str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str, ) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization, )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str,str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str, ) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project, )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str,str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str, ) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str,str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def
__init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -172,38 +298,107 @@ def __init__( transport (Union[str, ~.JobServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. 
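+        # Illustrative note: by this point `api_endpoint` has been resolved
+        # from `client_options` plus the GOOGLE_API_USE_MTLS_ENDPOINT and
+        # GOOGLE_API_USE_CLIENT_CERTIFICATE environment variables; with no
+        # overrides and no client certificate it is simply
+        # 'aiplatform.googleapis.com'.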
# Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, JobServiceTransport): - if credentials: + # transport is a JobServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: raise ValueError( "When providing a transport instance, " - "provide its credentials directly." + "provide its scopes directly." ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) - def create_custom_job( - self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job(self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -245,45 +440,57 @@ def create_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, custom_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateCustomJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if custom_job is not None: - request.custom_job = custom_job + has_flattened_params = any([parent, custom_job]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateCustomJobRequest): + request = job_service.CreateCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_custom_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_custom_job( - self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job(self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -318,49 +525,55 @@ def get_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.GetCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetCustomJobRequest): + request = job_service.GetCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_custom_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_custom_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_custom_jobs( - self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs(self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. 
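Each RPC in this file accepts either a full request object or the flattened fields, never both; the has_flattened_params check raises ValueError otherwise. A short usage sketch of the two calling styles (the resource name is a placeholder, and the import path assumes the package root re-exports the client):

    from google.cloud.aiplatform_v1beta1 import JobServiceClient
    from google.cloud.aiplatform_v1beta1.types import job_service

    client = JobServiceClient()
    job_name = "projects/my-project/locations/us-central1/customJobs/123"  # placeholder

    # Option 1: pass a request object.
    request = job_service.GetCustomJobRequest(name=job_name)
    job = client.get_custom_job(request=request)

    # Option 2: pass flattened fields instead.
    job = client.get_custom_job(name=job_name)

    # Passing both at once raises ValueError, per the sanity check above.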
Args: @@ -393,55 +606,64 @@ def list_custom_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.ListCustomJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListCustomJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListCustomJobsRequest): + request = job_service.ListCustomJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_custom_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_custom_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_custom_job( - self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_custom_job(self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a CustomJob. Args: @@ -486,30 +708,43 @@ def delete_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.DeleteCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteCustomJobRequest): + request = job_service.DeleteCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_custom_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -522,15 +757,14 @@ def delete_custom_job( # Done; return the response. return response - def cancel_custom_job( - self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job(self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -565,43 +799,53 @@ def cancel_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.CancelCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelCustomJobRequest): + request = job_service.CancelCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
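delete_custom_job above does not return the deleted resource; it wraps the server response in a google.api_core operation future (the ga_operation.from_gapic call), so callers typically block on result(). A hedged sketch, reusing the client and job_name placeholders from the earlier example:

    # Deletion is a long-running operation; ``operation`` is a future.
    operation = client.delete_custom_job(name=job_name)

    # Block until the server finishes the deletion (raises on failure).
    operation.result()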
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_custom_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def create_data_labeling_job( - self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + def create_data_labeling_job(self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -637,45 +881,57 @@ def create_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, data_labeling_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateDataLabelingJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if data_labeling_job is not None: - request.data_labeling_job = data_labeling_job + has_flattened_params = any([parent, data_labeling_job]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateDataLabelingJobRequest): + request = job_service.CreateDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_data_labeling_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. 
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_data_labeling_job( - self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job(self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -706,49 +962,55 @@ def get_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.GetDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetDataLabelingJobRequest): + request = job_service.GetDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_data_labeling_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_data_labeling_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_data_labeling_jobs( - self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs(self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. 
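The list methods above return pager objects that hide the page_token plumbing behind plain iteration. A usage sketch, building the parent with the common_location_path helper shown earlier (project and location values are placeholders):

    parent = JobServiceClient.common_location_path("my-project", "us-central1")

    # Iterating the pager fetches subsequent pages on demand.
    for custom_job in client.list_custom_jobs(parent=parent):
        print(custom_job.display_name)

    # Or walk whole response pages via the ``pages`` property.
    for page in client.list_custom_jobs(parent=parent).pages:
        print(len(page.custom_jobs))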
Args: @@ -780,55 +1042,64 @@ def list_data_labeling_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.ListDataLabelingJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListDataLabelingJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListDataLabelingJobsRequest): + request = job_service.ListDataLabelingJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_data_labeling_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_data_labeling_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_data_labeling_job( - self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_data_labeling_job(self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -874,30 +1145,43 @@ def delete_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.DeleteDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteDataLabelingJobRequest): + request = job_service.DeleteDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_data_labeling_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -910,15 +1194,14 @@ def delete_data_labeling_job( # Done; return the response. return response - def cancel_data_labeling_job( - self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job(self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -943,43 +1226,53 @@ def cancel_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.CancelDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelDataLabelingJobRequest): + request = job_service.CancelDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_data_labeling_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def create_hyperparameter_tuning_job( - self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + def create_hyperparameter_tuning_job(self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1017,45 +1310,57 @@ def create_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, hyperparameter_tuning_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateHyperparameterTuningJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if hyperparameter_tuning_job is not None: - request.hyperparameter_tuning_job = hyperparameter_tuning_job + has_flattened_params = any([parent, hyperparameter_tuning_job]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): + request = job_service.CreateHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method.wrap_method( - self._transport.create_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_hyperparameter_tuning_job( - self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job(self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1088,49 +1393,55 @@ def get_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.GetHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): + request = job_service.GetHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. 
return response - def list_hyperparameter_tuning_jobs( - self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs(self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1163,55 +1474,64 @@ def list_hyperparameter_tuning_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.ListHyperparameterTuningJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListHyperparameterTuningJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): + request = job_service.ListHyperparameterTuningJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_hyperparameter_tuning_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. 
return response - def delete_hyperparameter_tuning_job( - self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_hyperparameter_tuning_job(self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1257,30 +1577,43 @@ def delete_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.DeleteHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): + request = job_service.DeleteHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1293,15 +1626,14 @@ def delete_hyperparameter_tuning_job( # Done; return the response. return response - def cancel_hyperparameter_tuning_job( - self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job(self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. 
The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1339,43 +1671,53 @@ def cancel_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.CancelHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): + request = job_service.CancelHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) - def create_batch_prediction_job( - self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + def create_batch_prediction_job(self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1417,45 +1759,57 @@ def create_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, batch_prediction_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateBatchPredictionJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
- - if parent is not None: - request.parent = parent - if batch_prediction_job is not None: - request.batch_prediction_job = batch_prediction_job + has_flattened_params = any([parent, batch_prediction_job]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateBatchPredictionJobRequest): + request = job_service.CreateBatchPredictionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if batch_prediction_job is not None: + request.batch_prediction_job = batch_prediction_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_batch_prediction_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_batch_prediction_job( - self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job(self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1491,49 +1845,55 @@ def get_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.GetBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetBatchPredictionJobRequest): + request = job_service.GetBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_batch_prediction_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_batch_prediction_jobs( - self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs(self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1566,55 +1926,64 @@ def list_batch_prediction_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.ListBatchPredictionJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListBatchPredictionJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListBatchPredictionJobsRequest): + request = job_service.ListBatchPredictionJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_batch_prediction_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. 
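The routing-header metadata assembled in these methods is what lets the backend route by resource name. Roughly, the gapic_v1.routing_header helper percent-encodes field/value pairs into a single x-goog-request-params entry; a sketch of that behavior, not the library implementation itself:

    from urllib.parse import urlencode

    def to_grpc_metadata_sketch(params):
        # e.g. [("parent", "projects/my-project/locations/us-central1")]
        # -> ("x-goog-request-params",
        #     "parent=projects%2Fmy-project%2Flocations%2Fus-central1")
        return ("x-goog-request-params", urlencode(params))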
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job( - self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_batch_prediction_job(self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -1661,30 +2030,43 @@ def delete_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.DeleteBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): + request = job_service.DeleteBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_batch_prediction_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1697,15 +2079,14 @@ def delete_batch_prediction_job( # Done; return the response. 
return response - def cancel_batch_prediction_job( - self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job(self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. The @@ -1741,42 +2122,60 @@ def cancel_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = job_service.CancelBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelBatchPredictionJobRequest): + request = job_service.CancelBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_batch_prediction_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("JobServiceClient",) +__all__ = ( + 'JobServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 0dd74763fa..17cf187f9e 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. 
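The DEFAULT_CLIENT_INFO block above pins the user-agent to the installed distribution's version, falling back to a bare ClientInfo when package metadata is absent. The lookup it relies on can be exercised on its own:

    import pkg_resources

    try:
        version = pkg_resources.get_distribution("google-cloud-aiplatform").version
    except pkg_resources.DistributionNotFound:
        version = None  # DEFAULT_CLIENT_INFO then carries no gapic_version
    print(version)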
# -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job @@ -41,15 +41,12 @@ class ListCustomJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse - ], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - ): + def __init__(self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -59,10 +56,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListCustomJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -72,7 +72,7 @@ def pages(self) -> Iterable[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[custom_job.CustomJob]: @@ -80,7 +80,70 @@ def __iter__(self) -> Iterable[custom_job.CustomJob]: yield from page.custom_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListCustomJobsAsyncPager: + """A pager for iterating through ``list_custom_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListCustomJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``custom_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListCustomJobs`` requests and continue to iterate + through the ``custom_jobs`` field on the + corresponding responses. + + All the usual :class:`~.job_service.ListCustomJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListCustomJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListCustomJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = job_service.ListCustomJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListCustomJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[custom_job.CustomJob]: + async def async_generator(): + async for page in self.pages: + for response in page.custom_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListDataLabelingJobsPager: @@ -100,16 +163,12 @@ class ListDataLabelingJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, - ], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - ): + def __init__(self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -119,10 +178,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListDataLabelingJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -132,7 +194,7 @@ def pages(self) -> Iterable[job_service.ListDataLabelingJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: @@ -140,7 +202,70 @@ def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: yield from page.data_labeling_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListDataLabelingJobsAsyncPager: + """A pager for iterating through ``list_data_labeling_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListDataLabelingJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``data_labeling_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDataLabelingJobs`` requests and continue to iterate + through the ``data_labeling_jobs`` field on the + corresponding responses. + + All the usual :class:`~.job_service.ListDataLabelingJobsResponse` + attributes are available on the pager. 
If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListDataLabelingJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListDataLabelingJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListDataLabelingJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListDataLabelingJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[data_labeling_job.DataLabelingJob]: + async def async_generator(): + async for page in self.pages: + for response in page.data_labeling_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsPager: @@ -160,16 +285,12 @@ class ListHyperparameterTuningJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, - ], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - ): + def __init__(self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -179,10 +300,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -192,7 +316,7 @@ def pages(self) -> Iterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob]: @@ -200,7 +324,70 @@ def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob yield from page.hyperparameter_tuning_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListHyperparameterTuningJobsAsyncPager: + """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListHyperparameterTuningJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``hyperparameter_tuning_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListHyperparameterTuningJobs`` requests and continue to iterate + through the ``hyperparameter_tuning_jobs`` field on the + corresponding responses. + + All the usual :class:`~.job_service.ListHyperparameterTuningJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[job_service.ListHyperparameterTuningJobsResponse]], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = job_service.ListHyperparameterTuningJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + async def async_generator(): + async for page in self.pages: + for response in page.hyperparameter_tuning_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListBatchPredictionJobsPager: @@ -220,16 +407,12 @@ class ListBatchPredictionJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, - ], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - ): + def __init__(self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -239,10 +422,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListBatchPredictionJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -252,7 +438,7 @@ def pages(self) -> Iterable[job_service.ListBatchPredictionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: @@ -260,4 +446,67 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: yield from page.batch_prediction_jobs def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListBatchPredictionJobsAsyncPager: + """A pager for iterating through ``list_batch_prediction_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListBatchPredictionJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``batch_prediction_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListBatchPredictionJobs`` requests and continue to iterate + through the ``batch_prediction_jobs`` field on the + corresponding responses. 
+ + All the usual :class:`~.job_service.ListBatchPredictionJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListBatchPredictionJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListBatchPredictionJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListBatchPredictionJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListBatchPredictionJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[batch_prediction_job.BatchPredictionJob]: + async def async_generator(): + async for page in self.pages: + for response in page.batch_prediction_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index 2f081266a0..f46fff0524 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import JobServiceTransport from .grpc import JobServiceGrpcTransport +from .grpc_asyncio import JobServiceGrpcAsyncIOTransport # Compile a registry of transports. 
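The ``*AsyncPager`` classes completed above are driven with ``async for``; a rough usage sketch, assuming the ``JobServiceAsyncClient`` added elsewhere in this change and Application Default Credentials:

    import asyncio

    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient

    async def main():
        client = JobServiceAsyncClient()
        # The call returns a ListBatchPredictionJobsAsyncPager; iterating it
        # re-sends the stored metadata with each follow-up page request.
        pager = await client.list_batch_prediction_jobs(
            parent="projects/my-project/locations/us-central1",
        )
        async for job in pager:
            print(job.name)

    asyncio.run(main())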
_transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] -_transport_registry["grpc"] = JobServiceGrpcTransport +_transport_registry['grpc'] = JobServiceGrpcTransport +_transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport __all__ = ( - "JobServiceTransport", - "JobServiceGrpcTransport", + 'JobServiceTransport', + 'JobServiceGrpcTransport', + 'JobServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 6e11bb87ea..5bc1354ad9 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -17,41 +17,54 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -class JobServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. 
         Args:
@@ -61,182 +74,336 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
-        if ":" not in host:
-            host += ":443"
+        if ':' not in host:
+            host += ':443'
         self._host = host

         # If no credentials are provided, then determine the appropriate
         # defaults.
-        if credentials is None:
-            credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file,
+                scopes=scopes,
+                quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)

         # Save the credentials.
         self._credentials = credentials

+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
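+        # Each entry below runs the raw transport method through
+        # gapic_v1.method.wrap_method, which layers the retry/timeout defaults
+        # and the client_info user-agent metadata onto every outgoing call.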
+ self._wrapped_methods = { + self.create_custom_job: gapic_v1.method.wrap_method( + self.create_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.get_custom_job: gapic_v1.method.wrap_method( + self.get_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.list_custom_jobs: gapic_v1.method.wrap_method( + self.list_custom_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_custom_job: gapic_v1.method.wrap_method( + self.delete_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_custom_job: gapic_v1.method.wrap_method( + self.cancel_custom_job, + default_timeout=None, + client_info=client_info, + ), + self.create_data_labeling_job: gapic_v1.method.wrap_method( + self.create_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.get_data_labeling_job: gapic_v1.method.wrap_method( + self.get_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.list_data_labeling_jobs: gapic_v1.method.wrap_method( + self.list_data_labeling_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_data_labeling_job: gapic_v1.method.wrap_method( + self.delete_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_data_labeling_job: gapic_v1.method.wrap_method( + self.cancel_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.create_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.get_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.list_hyperparameter_tuning_jobs: gapic_v1.method.wrap_method( + self.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.create_batch_prediction_job: gapic_v1.method.wrap_method( + self.create_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.get_batch_prediction_job: gapic_v1.method.wrap_method( + self.get_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.list_batch_prediction_jobs: gapic_v1.method.wrap_method( + self.list_batch_prediction_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_batch_prediction_job: gapic_v1.method.wrap_method( + self.delete_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_batch_prediction_job: gapic_v1.method.wrap_method( + self.cancel_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property - def create_custom_job( - self, - ) -> typing.Callable[ - [job_service.CreateCustomJobRequest], gca_custom_job.CustomJob - ]: - raise NotImplementedError + def create_custom_job(self) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + 
gca_custom_job.CustomJob, + typing.Awaitable[gca_custom_job.CustomJob] + ]]: + raise NotImplementedError() @property - def get_custom_job( - self, - ) -> typing.Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: - raise NotImplementedError + def get_custom_job(self) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[ + custom_job.CustomJob, + typing.Awaitable[custom_job.CustomJob] + ]]: + raise NotImplementedError() @property - def list_custom_jobs( - self, - ) -> typing.Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse - ]: - raise NotImplementedError + def list_custom_jobs(self) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse] + ]]: + raise NotImplementedError() @property - def delete_custom_job( - self, - ) -> typing.Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: - raise NotImplementedError + def delete_custom_job(self) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def cancel_custom_job( - self, - ) -> typing.Callable[[job_service.CancelCustomJobRequest], empty.Empty]: - raise NotImplementedError + def cancel_custom_job(self) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() @property - def create_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob, - ]: - raise NotImplementedError + def create_data_labeling_job(self) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob] + ]]: + raise NotImplementedError() @property - def get_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob - ]: - raise NotImplementedError + def get_data_labeling_job(self) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob] + ]]: + raise NotImplementedError() @property - def list_data_labeling_jobs( - self, - ) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, - ]: - raise NotImplementedError + def list_data_labeling_jobs(self) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse] + ]]: + raise NotImplementedError() @property - def delete_data_labeling_job( - self, - ) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], operations.Operation - ]: - raise NotImplementedError + def delete_data_labeling_job(self) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def cancel_data_labeling_job( - self, - ) -> typing.Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: - raise NotImplementedError + def cancel_data_labeling_job(self) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[ 
+ empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() @property - def create_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - ]: - raise NotImplementedError + def create_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] + ]]: + raise NotImplementedError() @property - def get_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob, - ]: - raise NotImplementedError + def get_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] + ]]: + raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs( - self, - ) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, - ]: - raise NotImplementedError + def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ]]: + raise NotImplementedError() @property - def delete_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation - ]: - raise NotImplementedError + def delete_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job( - self, - ) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], empty.Empty - ]: - raise NotImplementedError + def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() @property - def create_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob, - ]: - raise NotImplementedError + def create_batch_prediction_job(self) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] + ]]: + raise NotImplementedError() @property - def get_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob, - ]: - raise NotImplementedError + def get_batch_prediction_job(self) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob] + ]]: + raise NotImplementedError() - @property - def list_batch_prediction_jobs( - self, - ) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, 
- ]: - raise NotImplementedError + @property + def list_batch_prediction_jobs(self) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse] + ]]: + raise NotImplementedError() - @property - def delete_batch_prediction_job( - self, - ) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], operations.Operation - ]: - raise NotImplementedError + @property + def delete_batch_prediction_job(self) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() - @property - def cancel_batch_prediction_job( - self, - ) -> typing.Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: - raise NotImplementedError + @property + def cancel_batch_prediction_job(self) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() -__all__ = ("JobServiceTransport",) +__all__ = ( + 'JobServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index a598c180cf..8523b62d35 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -15,33 +15,31 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobServiceTransport +from .base import JobServiceTransport, DEFAULT_CLIENT_INFO class JobServiceGrpcTransport(JobServiceTransport): @@ -56,14 +54,20 @@ class 
JobServiceGrpcTransport(JobServiceTransport):
     It sends protocol buffers over the wire using gRPC (which is built
     on top of HTTP/2); the ``grpcio`` package must be installed.
     """
-
-    def __init__(
-        self,
-        *,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        channel: grpc.Channel = None
-    ) -> None:
+    _stubs: Dict[str, Callable]
+
+    def __init__(self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: str = None,
+            scopes: Sequence[str] = None,
+            channel: grpc.Channel = None,
+            api_mtls_endpoint: str = None,
+            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+            ssl_channel_credentials: grpc.ChannelCredentials = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
         """Instantiate the transport.

         Args:
@@ -74,29 +78,107 @@ def __init__(
             are specified, the client will attempt to ascertain the
             credentials from the environment.
             This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+          google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+              creation failed for any reason.
+          google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+              and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False

-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
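+            # client_cert_source is a zero-argument callback returning the
+            # client certificate chain and private key as PEM bytes; when it
+            # is absent, application default SSL credentials are used instead.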
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]

-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )

     @classmethod
-    def create_channel(
-        cls,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        **kwargs
-    ) -> grpc.Channel:
+    def create_channel(cls,
+                       host: str = 'aiplatform.googleapis.com',
+                       credentials: credentials.Credentials = None,
+                       credentials_file: str = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
             host (Optional[str]): The host for the channel to use.
@@ -105,30 +187,37 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
+        scopes = scopes or cls.AUTH_SCOPES
         return grpc_helpers.create_channel(
-            host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
         )

     @property
     def grpc_channel(self) -> grpc.Channel:
-        """Create the channel designed to connect to this service.
-
-        This property caches on the instance; repeated calls return
-        the same channel.
+        """Return the channel designed to connect to this service.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
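The ``create_channel`` classmethod above can also be called directly when a caller wants to own the channel; a minimal sketch, assuming default credentials are available in the environment (``my-billing-project`` is a placeholder):

    from google.cloud.aiplatform_v1beta1.services.job_service.transports import (
        JobServiceGrpcTransport,
    )

    # Build the channel explicitly, then hand it to the transport; a provided
    # channel short-circuits all credential and mTLS handling in __init__.
    channel = JobServiceGrpcTransport.create_channel(
        "aiplatform.googleapis.com",
        quota_project_id="my-billing-project",
    )
    transport = JobServiceGrpcTransport(channel=channel)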
- if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property @@ -139,18 +228,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self.__dict__['operations_client'] @property - def create_custom_job( - self, - ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: + def create_custom_job(self) -> Callable[ + [job_service.CreateCustomJobRequest], + gca_custom_job.CustomJob]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -166,18 +255,18 @@ def create_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_custom_job" not in self._stubs: - self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", + if 'create_custom_job' not in self._stubs: + self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs["create_custom_job"] + return self._stubs['create_custom_job'] @property - def get_custom_job( - self, - ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: + def get_custom_job(self) -> Callable[ + [job_service.GetCustomJobRequest], + custom_job.CustomJob]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -192,20 +281,18 @@ def get_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_custom_job" not in self._stubs: - self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", + if 'get_custom_job' not in self._stubs: + self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs["get_custom_job"] + return self._stubs['get_custom_job'] @property - def list_custom_jobs( - self, - ) -> Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse - ]: + def list_custom_jobs(self) -> Callable[ + [job_service.ListCustomJobsRequest], + job_service.ListCustomJobsResponse]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -220,18 +307,18 @@ def list_custom_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "list_custom_jobs" not in self._stubs: - self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", + if 'list_custom_jobs' not in self._stubs: + self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs["list_custom_jobs"] + return self._stubs['list_custom_jobs'] @property - def delete_custom_job( - self, - ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: + def delete_custom_job(self) -> Callable[ + [job_service.DeleteCustomJobRequest], + operations.Operation]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -246,18 +333,18 @@ def delete_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_custom_job" not in self._stubs: - self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", + if 'delete_custom_job' not in self._stubs: + self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_custom_job"] + return self._stubs['delete_custom_job'] @property - def cancel_custom_job( - self, - ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: + def cancel_custom_job(self) -> Callable[ + [job_service.CancelCustomJobRequest], + empty.Empty]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -284,21 +371,18 @@ def cancel_custom_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_custom_job" not in self._stubs: - self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", + if 'cancel_custom_job' not in self._stubs: + self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_custom_job"] + return self._stubs['cancel_custom_job'] @property - def create_data_labeling_job( - self, - ) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob, - ]: + def create_data_labeling_job(self) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -313,20 +397,18 @@ def create_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "create_data_labeling_job" not in self._stubs: - self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", + if 'create_data_labeling_job' not in self._stubs: + self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["create_data_labeling_job"] + return self._stubs['create_data_labeling_job'] @property - def get_data_labeling_job( - self, - ) -> Callable[ - [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob - ]: + def get_data_labeling_job(self) -> Callable[ + [job_service.GetDataLabelingJobRequest], + data_labeling_job.DataLabelingJob]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -341,21 +423,18 @@ def get_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_data_labeling_job" not in self._stubs: - self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", + if 'get_data_labeling_job' not in self._stubs: + self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs["get_data_labeling_job"] + return self._stubs['get_data_labeling_job'] @property - def list_data_labeling_jobs( - self, - ) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, - ]: + def list_data_labeling_jobs(self) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -370,18 +449,18 @@ def list_data_labeling_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_data_labeling_jobs" not in self._stubs: - self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", + if 'list_data_labeling_jobs' not in self._stubs: + self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs["list_data_labeling_jobs"] + return self._stubs['list_data_labeling_jobs'] @property - def delete_data_labeling_job( - self, - ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: + def delete_data_labeling_job(self) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], + operations.Operation]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -396,18 +475,18 @@ def delete_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "delete_data_labeling_job" not in self._stubs: - self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", + if 'delete_data_labeling_job' not in self._stubs: + self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_data_labeling_job"] + return self._stubs['delete_data_labeling_job'] @property - def cancel_data_labeling_job( - self, - ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: + def cancel_data_labeling_job(self) -> Callable[ + [job_service.CancelDataLabelingJobRequest], + empty.Empty]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -423,21 +502,18 @@ def cancel_data_labeling_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_data_labeling_job" not in self._stubs: - self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", + if 'cancel_data_labeling_job' not in self._stubs: + self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_data_labeling_job"] + return self._stubs['cancel_data_labeling_job'] @property - def create_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - ]: + def create_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -453,23 +529,18 @@ def create_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "create_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", + if 'create_hyperparameter_tuning_job' not in self._stubs: + self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["create_hyperparameter_tuning_job"] + return self._stubs['create_hyperparameter_tuning_job'] @property - def get_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob, - ]: + def get_hyperparameter_tuning_job(self) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. 
Gets a HyperparameterTuningJob @@ -484,23 +555,18 @@ def get_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "get_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", + if 'get_hyperparameter_tuning_job' not in self._stubs: + self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs["get_hyperparameter_tuning_job"] + return self._stubs['get_hyperparameter_tuning_job'] @property - def list_hyperparameter_tuning_jobs( - self, - ) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, - ]: + def list_hyperparameter_tuning_jobs(self) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -516,22 +582,18 @@ def list_hyperparameter_tuning_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_hyperparameter_tuning_jobs" not in self._stubs: - self._stubs[ - "list_hyperparameter_tuning_jobs" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", + if 'list_hyperparameter_tuning_jobs' not in self._stubs: + self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs["list_hyperparameter_tuning_jobs"] + return self._stubs['list_hyperparameter_tuning_jobs'] @property - def delete_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation - ]: + def delete_hyperparameter_tuning_job(self) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + operations.Operation]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -547,20 +609,18 @@ def delete_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
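# The Callable[...] return annotations being rewritten in these hunks give
# every stub a precise function type, so a type checker can match each request
# message to its RPC. A small self-contained sketch of the same shape, using
# stand-in types rather than the generated messages:

from typing import Callable

class DeleteJobRequest:  # stand-in for a generated request message
    def __init__(self, name: str) -> None:
        self.name = name

class Operation:  # stand-in for a google.longrunning Operation
    done: bool = False

def make_stub() -> Callable[[DeleteJobRequest], Operation]:
    def stub(request: DeleteJobRequest) -> Operation:
        # A real stub would serialize the request and invoke the channel.
        return Operation()
    return stub

op = make_stub()(DeleteJobRequest('jobs/123'))  # type-checks end to end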
- if "delete_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "delete_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", + if 'delete_hyperparameter_tuning_job' not in self._stubs: + self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_hyperparameter_tuning_job"] + return self._stubs['delete_hyperparameter_tuning_job'] @property - def cancel_hyperparameter_tuning_job( - self, - ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]: + def cancel_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + empty.Empty]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -589,23 +649,18 @@ def cancel_hyperparameter_tuning_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "cancel_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", + if 'cancel_hyperparameter_tuning_job' not in self._stubs: + self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_hyperparameter_tuning_job"] + return self._stubs['cancel_hyperparameter_tuning_job'] @property - def create_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob, - ]: + def create_batch_prediction_job(self) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + gca_batch_prediction_job.BatchPredictionJob]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -621,21 +676,18 @@ def create_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "create_batch_prediction_job" not in self._stubs: - self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", + if 'create_batch_prediction_job' not in self._stubs: + self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["create_batch_prediction_job"] + return self._stubs['create_batch_prediction_job'] @property - def get_batch_prediction_job( - self, - ) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob, - ]: + def get_batch_prediction_job(self) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -650,21 +702,18 @@ def get_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_batch_prediction_job" not in self._stubs: - self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", + if 'get_batch_prediction_job' not in self._stubs: + self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs["get_batch_prediction_job"] + return self._stubs['get_batch_prediction_job'] @property - def list_batch_prediction_jobs( - self, - ) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, - ]: + def list_batch_prediction_jobs(self) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -679,18 +728,18 @@ def list_batch_prediction_jobs( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_batch_prediction_jobs" not in self._stubs: - self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", + if 'list_batch_prediction_jobs' not in self._stubs: + self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs["list_batch_prediction_jobs"] + return self._stubs['list_batch_prediction_jobs'] @property - def delete_batch_prediction_job( - self, - ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: + def delete_batch_prediction_job(self) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], + operations.Operation]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. 
Can only be called on @@ -706,18 +755,18 @@ def delete_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_batch_prediction_job" not in self._stubs: - self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", + if 'delete_batch_prediction_job' not in self._stubs: + self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_batch_prediction_job"] + return self._stubs['delete_batch_prediction_job'] @property - def cancel_batch_prediction_job( - self, - ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: + def cancel_batch_prediction_job(self) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], + empty.Empty]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -743,13 +792,15 @@ def cancel_batch_prediction_job( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_batch_prediction_job" not in self._stubs: - self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", + if 'cancel_batch_prediction_job' not in self._stubs: + self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_batch_prediction_job"] + return self._stubs['cancel_batch_prediction_job'] -__all__ = ("JobServiceGrpcTransport",) +__all__ = ( + 'JobServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..ac8e04e542 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -0,0 +1,811 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import batch_prediction_job
+from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job
+from google.cloud.aiplatform_v1beta1.types import custom_job
+from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
+from google.cloud.aiplatform_v1beta1.types import data_labeling_job
+from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job
+from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job
+from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job
+from google.cloud.aiplatform_v1beta1.types import job_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+from .base import JobServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import JobServiceGrpcTransport
+
+
+class JobServiceGrpcAsyncIOTransport(JobServiceTransport):
+    """gRPC AsyncIO backend transport for JobService.
+
+    A service for creating and managing AI Platform's jobs.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(cls,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            **kwargs) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+ """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. 
+ if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + + @property + def create_custom_job(self) -> Callable[ + [job_service.CreateCustomJobRequest], + Awaitable[gca_custom_job.CustomJob]]: + r"""Return a callable for the create custom job method over gRPC. + + Creates a CustomJob. A created CustomJob right away + will be attempted to be run. + + Returns: + Callable[[~.CreateCustomJobRequest], + Awaitable[~.CustomJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_custom_job' not in self._stubs: + self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', + request_serializer=job_service.CreateCustomJobRequest.serialize, + response_deserializer=gca_custom_job.CustomJob.deserialize, + ) + return self._stubs['create_custom_job'] + + @property + def get_custom_job(self) -> Callable[ + [job_service.GetCustomJobRequest], + Awaitable[custom_job.CustomJob]]: + r"""Return a callable for the get custom job method over gRPC. + + Gets a CustomJob. + + Returns: + Callable[[~.GetCustomJobRequest], + Awaitable[~.CustomJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_custom_job' not in self._stubs: + self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', + request_serializer=job_service.GetCustomJobRequest.serialize, + response_deserializer=custom_job.CustomJob.deserialize, + ) + return self._stubs['get_custom_job'] + + @property + def list_custom_jobs(self) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse]]: + r"""Return a callable for the list custom jobs method over gRPC. + + Lists CustomJobs in a Location. + + Returns: + Callable[[~.ListCustomJobsRequest], + Awaitable[~.ListCustomJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_custom_jobs' not in self._stubs: + self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', + request_serializer=job_service.ListCustomJobsRequest.serialize, + response_deserializer=job_service.ListCustomJobsResponse.deserialize, + ) + return self._stubs['list_custom_jobs'] + + @property + def delete_custom_job(self) -> Callable[ + [job_service.DeleteCustomJobRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete custom job method over gRPC. + + Deletes a CustomJob. + + Returns: + Callable[[~.DeleteCustomJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_custom_job' not in self._stubs: + self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', + request_serializer=job_service.DeleteCustomJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_custom_job'] + + @property + def cancel_custom_job(self) -> Callable[ + [job_service.CancelCustomJobRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the cancel custom job method over gRPC. + + Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelCustomJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
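# These Awaitable-typed stubs back the JobServiceAsyncClient added elsewhere
# in this patch. A hedged usage sketch (the resource name is a placeholder,
# and application default credentials are assumed to be available):

import asyncio

from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient

async def show_job(name: str) -> None:
    client = JobServiceAsyncClient()
    job = await client.get_custom_job(name=name)
    print(job.display_name, job.state)

# asyncio.run(show_job('projects/my-project/locations/us-central1/customJobs/123'))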
+ if 'cancel_custom_job' not in self._stubs: + self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', + request_serializer=job_service.CancelCustomJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['cancel_custom_job'] + + @property + def create_data_labeling_job(self) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob]]: + r"""Return a callable for the create data labeling job method over gRPC. + + Creates a DataLabelingJob. + + Returns: + Callable[[~.CreateDataLabelingJobRequest], + Awaitable[~.DataLabelingJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_data_labeling_job' not in self._stubs: + self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', + request_serializer=job_service.CreateDataLabelingJobRequest.serialize, + response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs['create_data_labeling_job'] + + @property + def get_data_labeling_job(self) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob]]: + r"""Return a callable for the get data labeling job method over gRPC. + + Gets a DataLabelingJob. + + Returns: + Callable[[~.GetDataLabelingJobRequest], + Awaitable[~.DataLabelingJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_data_labeling_job' not in self._stubs: + self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', + request_serializer=job_service.GetDataLabelingJobRequest.serialize, + response_deserializer=data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs['get_data_labeling_job'] + + @property + def list_data_labeling_jobs(self) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse]]: + r"""Return a callable for the list data labeling jobs method over gRPC. + + Lists DataLabelingJobs in a Location. + + Returns: + Callable[[~.ListDataLabelingJobsRequest], + Awaitable[~.ListDataLabelingJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
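# The request_serializer hooks used throughout these stubs are proto-plus
# helpers: Message.serialize(instance) produces wire bytes and
# Message.deserialize(bytes) reconstructs the message. A small round trip
# with a request type from this patch (the resource name is a placeholder):

from google.cloud.aiplatform_v1beta1.types import job_service

request = job_service.GetDataLabelingJobRequest(
    name='projects/p/locations/l/dataLabelingJobs/1')  # placeholder name
wire = job_service.GetDataLabelingJobRequest.serialize(request)  # bytes
assert job_service.GetDataLabelingJobRequest.deserialize(wire) == request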
+ if 'list_data_labeling_jobs' not in self._stubs: + self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', + request_serializer=job_service.ListDataLabelingJobsRequest.serialize, + response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, + ) + return self._stubs['list_data_labeling_jobs'] + + @property + def delete_data_labeling_job(self) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete data labeling job method over gRPC. + + Deletes a DataLabelingJob. + + Returns: + Callable[[~.DeleteDataLabelingJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_data_labeling_job' not in self._stubs: + self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', + request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_data_labeling_job'] + + @property + def cancel_data_labeling_job(self) -> Callable[ + [job_service.CancelDataLabelingJobRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the cancel data labeling job method over gRPC. + + Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Returns: + Callable[[~.CancelDataLabelingJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'cancel_data_labeling_job' not in self._stubs: + self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', + request_serializer=job_service.CancelDataLabelingJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['cancel_data_labeling_job'] + + @property + def create_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob]]: + r"""Return a callable for the create hyperparameter tuning + job method over gRPC. + + Creates a HyperparameterTuningJob + + Returns: + Callable[[~.CreateHyperparameterTuningJobRequest], + Awaitable[~.HyperparameterTuningJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
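# By contrast, the Operation and Empty responses above are plain protobuf
# classes, so the stubs plug in FromString instead of a proto-plus
# deserialize. The two shapes round-trip the same way:

from google.protobuf import empty_pb2

wire = empty_pb2.Empty().SerializeToString()  # what gRPC sends on the wire
message = empty_pb2.Empty.FromString(wire)    # what the stub deserializes
assert message == empty_pb2.Empty()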
+ if 'create_hyperparameter_tuning_job' not in self._stubs: + self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', + request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, + response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs['create_hyperparameter_tuning_job'] + + @property + def get_hyperparameter_tuning_job(self) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob]]: + r"""Return a callable for the get hyperparameter tuning job method over gRPC. + + Gets a HyperparameterTuningJob + + Returns: + Callable[[~.GetHyperparameterTuningJobRequest], + Awaitable[~.HyperparameterTuningJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_hyperparameter_tuning_job' not in self._stubs: + self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', + request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, + response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs['get_hyperparameter_tuning_job'] + + @property + def list_hyperparameter_tuning_jobs(self) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse]]: + r"""Return a callable for the list hyperparameter tuning + jobs method over gRPC. + + Lists HyperparameterTuningJobs in a Location. + + Returns: + Callable[[~.ListHyperparameterTuningJobsRequest], + Awaitable[~.ListHyperparameterTuningJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_hyperparameter_tuning_jobs' not in self._stubs: + self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', + request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, + response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, + ) + return self._stubs['list_hyperparameter_tuning_jobs'] + + @property + def delete_hyperparameter_tuning_job(self) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete hyperparameter tuning + job method over gRPC. + + Deletes a HyperparameterTuningJob. + + Returns: + Callable[[~.DeleteHyperparameterTuningJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
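# Besides the _stubs dict, this transport caches its operations_client
# (earlier in this file) directly in self.__dict__, so each instance builds
# exactly one OperationsAsyncClient. The same idiom in isolation:

class Cached:
    @property
    def expensive(self):
        # Only build the value the first time the property is read; the
        # property body finds the stored entry on every later access.
        if 'expensive' not in self.__dict__:
            self.__dict__['expensive'] = object()  # stand-in for client setup
        return self.__dict__['expensive']

c = Cached()
assert c.expensive is c.expensive  # constructed once, then reused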
+ if 'delete_hyperparameter_tuning_job' not in self._stubs: + self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', + request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_hyperparameter_tuning_job'] + + @property + def cancel_hyperparameter_tuning_job(self) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the cancel hyperparameter tuning + job method over gRPC. + + Cancels a HyperparameterTuningJob. Starts asynchronous + cancellation on the HyperparameterTuningJob. The server makes a + best effort to cancel the job, but success is not guaranteed. + Clients can use + ``JobService.GetHyperparameterTuningJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the HyperparameterTuningJob is not deleted; + instead it becomes a job with a + ``HyperparameterTuningJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``HyperparameterTuningJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelHyperparameterTuningJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'cancel_hyperparameter_tuning_job' not in self._stubs: + self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', + request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['cancel_hyperparameter_tuning_job'] + + @property + def create_batch_prediction_job(self) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob]]: + r"""Return a callable for the create batch prediction job method over gRPC. + + Creates a BatchPredictionJob. A BatchPredictionJob + once created will right away be attempted to start. + + Returns: + Callable[[~.CreateBatchPredictionJobRequest], + Awaitable[~.BatchPredictionJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_batch_prediction_job' not in self._stubs: + self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', + request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, + response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, + ) + return self._stubs['create_batch_prediction_job'] + + @property + def get_batch_prediction_job(self) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob]]: + r"""Return a callable for the get batch prediction job method over gRPC. 
+ + Gets a BatchPredictionJob + + Returns: + Callable[[~.GetBatchPredictionJobRequest], + Awaitable[~.BatchPredictionJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_batch_prediction_job' not in self._stubs: + self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', + request_serializer=job_service.GetBatchPredictionJobRequest.serialize, + response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, + ) + return self._stubs['get_batch_prediction_job'] + + @property + def list_batch_prediction_jobs(self) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse]]: + r"""Return a callable for the list batch prediction jobs method over gRPC. + + Lists BatchPredictionJobs in a Location. + + Returns: + Callable[[~.ListBatchPredictionJobsRequest], + Awaitable[~.ListBatchPredictionJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_batch_prediction_jobs' not in self._stubs: + self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', + request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, + response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, + ) + return self._stubs['list_batch_prediction_jobs'] + + @property + def delete_batch_prediction_job(self) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete batch prediction job method over gRPC. + + Deletes a BatchPredictionJob. Can only be called on + jobs that already finished. + + Returns: + Callable[[~.DeleteBatchPredictionJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_batch_prediction_job' not in self._stubs: + self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', + request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_batch_prediction_job'] + + @property + def cancel_batch_prediction_job(self) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the cancel batch prediction job method over gRPC. + + Cancels a BatchPredictionJob. + + Starts asynchronous cancellation on the BatchPredictionJob. The + server makes the best effort to cancel the job, but success is + not guaranteed. Clients can use + ``JobService.GetBatchPredictionJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. 
On a successful
+        cancellation, the BatchPredictionJob is not deleted; instead its
+        ``BatchPredictionJob.state``
+        is set to ``CANCELLED``. Any files already outputted by the job
+        are not deleted.
+
+        Returns:
+            Callable[[~.CancelBatchPredictionJobRequest],
+                    Awaitable[~.Empty]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'cancel_batch_prediction_job' not in self._stubs:
+            self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary(
+                '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob',
+                request_serializer=job_service.CancelBatchPredictionJobRequest.serialize,
+                response_deserializer=empty.Empty.FromString,
+            )
+        return self._stubs['cancel_batch_prediction_job']
+
+
+__all__ = (
+    'JobServiceGrpcAsyncIOTransport',
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py
new file mode 100644
index 0000000000..c533a12b45
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .client import MigrationServiceClient
+from .async_client import MigrationServiceAsyncClient
+
+__all__ = (
+    'MigrationServiceClient',
+    'MigrationServiceAsyncClient',
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py
new file mode 100644
index 0000000000..dea3b50632
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py
@@ -0,0 +1,357 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import OrderedDict
+import functools
+import re
+from typing import Dict, Sequence, Tuple, Type, Union
+import pkg_resources
+
+import google.api_core.client_options as ClientOptions  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.api_core import operation  # type: ignore
+from google.api_core import operation_async  # type: ignore
+from google.cloud.aiplatform_v1beta1.services.migration_service import pagers
+from google.cloud.aiplatform_v1beta1.types import migratable_resource
+from google.cloud.aiplatform_v1beta1.types import migration_service
+
+from .transports.base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc_asyncio import MigrationServiceGrpcAsyncIOTransport
+from .client import MigrationServiceClient
+
+
+class MigrationServiceAsyncClient:
+    """A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+    """
+
+    _client: MigrationServiceClient
+
+    DEFAULT_ENDPOINT = MigrationServiceClient.DEFAULT_ENDPOINT
+    DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT
+
+    annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path)
+    parse_annotated_dataset_path = staticmethod(MigrationServiceClient.parse_annotated_dataset_path)
+    dataset_path = staticmethod(MigrationServiceClient.dataset_path)
+    parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path)
+    model_path = staticmethod(MigrationServiceClient.model_path)
+    parse_model_path = staticmethod(MigrationServiceClient.parse_model_path)
+    version_path = staticmethod(MigrationServiceClient.version_path)
+    parse_version_path = staticmethod(MigrationServiceClient.parse_version_path)
+
+    common_billing_account_path = staticmethod(MigrationServiceClient.common_billing_account_path)
+    parse_common_billing_account_path = staticmethod(MigrationServiceClient.parse_common_billing_account_path)
+
+    common_folder_path = staticmethod(MigrationServiceClient.common_folder_path)
+    parse_common_folder_path = staticmethod(MigrationServiceClient.parse_common_folder_path)
+
+    common_organization_path = staticmethod(MigrationServiceClient.common_organization_path)
+    parse_common_organization_path = staticmethod(MigrationServiceClient.parse_common_organization_path)
+
+    common_project_path = staticmethod(MigrationServiceClient.common_project_path)
+    parse_common_project_path = staticmethod(MigrationServiceClient.parse_common_project_path)
+
+    common_location_path = staticmethod(MigrationServiceClient.common_location_path)
+    parse_common_location_path = staticmethod(MigrationServiceClient.parse_common_location_path)
+
+    from_service_account_file = MigrationServiceClient.from_service_account_file
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) ->
MigrationServiceTransport: + """Return the transport used by the client instance. + + Returns: + MigrationServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, MigrationServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the migration service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.MigrationServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = MigrationServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def search_migratable_resources(self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: + r"""Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Args: + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The request object. Request message for + ``MigrationService.SearchMigratableResources``. + parent (:class:`str`): + Required. The location that the migratable resources + should be searched from. It's the AI Platform location + that the resources can be migrated to, not the + resources' original location. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.SearchMigratableResourcesAsyncPager: + Response message for + ``MigrationService.SearchMigratableResources``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = migration_service.SearchMigratableResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.search_migratable_resources, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.SearchMigratableResourcesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_migrate_resources(self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Args: + request (:class:`~.migration_service.BatchMigrateResourcesRequest`): + The request object. Request message for + ``MigrationService.BatchMigrateResources``. + parent (:class:`str`): + Required. The location of the migrated resource will + live in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + migrate_resource_requests (:class:`Sequence[~.migration_service.MigrateResourceRequest]`): + Required. The request messages + specifying the resources to migrate. + They must be in the same location as the + destination. Up to 50 resources can be + migrated in one batch. + This corresponds to the ``migrate_resource_requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. 
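# search_migratable_resources above returns an AsyncPager whose async
# iteration fetches follow-up pages transparently. A hedged usage sketch
# (the parent value is a placeholder and default credentials are assumed):

import asyncio

from google.cloud.aiplatform_v1beta1.services.migration_service import (
    MigrationServiceAsyncClient,
)

async def list_migratable(parent: str):
    client = MigrationServiceAsyncClient()
    pager = await client.search_migratable_resources(parent=parent)
    # `async for` resolves additional pages as it iterates.
    return [resource async for resource in pager]

# asyncio.run(list_migratable('projects/my-project/locations/us-central1'))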
+ + The result type for the operation will be + :class:`~.migration_service.BatchMigrateResourcesResponse`: + Response message for + ``MigrationService.BatchMigrateResources``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, migrate_resource_requests]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = migration_service.BatchMigrateResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if migrate_resource_requests is not None: + request.migrate_resource_requests = migrate_resource_requests + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_migrate_resources, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + migration_service.BatchMigrateResourcesResponse, + metadata_type=migration_service.BatchMigrateResourcesOperationMetadata, + ) + + # Done; return the response. + return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'MigrationServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py new file mode 100644 index 0000000000..2acb4a0ac7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -0,0 +1,607 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
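# batch_migrate_resources above wraps its response with
# operation_async.from_gapic, so callers await the returned AsyncOperation's
# result() to obtain the BatchMigrateResourcesResponse. A hedged usage sketch
# (parent and the request list are placeholders; the docstring caps a batch
# at 50 requests):

from google.cloud.aiplatform_v1beta1.services.migration_service import (
    MigrationServiceAsyncClient,
)

async def migrate(parent: str, requests):
    client = MigrationServiceAsyncClient()
    lro = await client.batch_migrate_resources(
        parent=parent,
        migrate_resource_requests=requests,  # up to 50 MigrateResourceRequests
    )
    return await lro.result()  # BatchMigrateResourcesResponse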
+#
+
+from collections import OrderedDict
+from distutils import util
+import os
+import re
+from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.api_core import operation  # type: ignore
+from google.api_core import operation_async  # type: ignore
+from google.cloud.aiplatform_v1beta1.services.migration_service import pagers
+from google.cloud.aiplatform_v1beta1.types import migratable_resource
+from google.cloud.aiplatform_v1beta1.types import migration_service
+
+from .transports.base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import MigrationServiceGrpcTransport
+from .transports.grpc_asyncio import MigrationServiceGrpcAsyncIOTransport
+
+
+class MigrationServiceClientMeta(type):
+    """Metaclass for the MigrationService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+    _transport_registry = OrderedDict()  # type: Dict[str, Type[MigrationServiceTransport]]
+    _transport_registry['grpc'] = MigrationServiceGrpcTransport
+    _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport
+
+    def get_transport_class(cls,
+            label: str = None,
+            ) -> Type[MigrationServiceTransport]:
+        """Return an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class MigrationServiceClient(metaclass=MigrationServiceClientMeta):
+    """A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
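+            # Named groups: ``name`` captures the service prefix, while the
+            # optional ``mtls``, ``sandbox`` and ``googledomain`` groups
+            # record which suffix components are already present in the
+            # endpoint.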
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+        file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            MigrationServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs['credentials'] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> MigrationServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            MigrationServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
+
+    @staticmethod
+    def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str:
+        """Return a fully-qualified annotated_dataset string."""
+        return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, )
+
+    @staticmethod
+    def parse_annotated_dataset_path(path: str) -> Dict[str,str]:
+        """Parse an annotated_dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)/annotatedDatasets/(?P<annotated_dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str,location: str,dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str,str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str,dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str,str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str,location: str,model: str,) -> str:
+        """Return a fully-qualified model string."""
string.""" + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parse a model path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_path(project: str,location: str,model: str,) -> str: + """Return a fully-qualified model string.""" + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parse a model path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def version_path(project: str,model: str,version: str,) -> str: + """Return a fully-qualified version string.""" + return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + + @staticmethod + def parse_version_path(path: str) -> Dict[str,str]: + """Parse a version path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: 
Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the migration service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.MigrationServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. 
+ # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, MigrationServiceTransport): + # transport is a MigrationServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def search_migratable_resources(self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: + r"""Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Args: + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The request object. Request message for + ``MigrationService.SearchMigratableResources``. + parent (:class:`str`): + Required. The location that the migratable resources + should be searched from. It's the AI Platform location + that the resources can be migrated to, not the + resources' original location. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.SearchMigratableResourcesPager: + Response message for + ``MigrationService.SearchMigratableResources``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a migration_service.SearchMigratableResourcesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, migration_service.SearchMigratableResourcesRequest): + request = migration_service.SearchMigratableResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
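+        # The wrapped method carries the defaults configured in the
+        # transport's ``_prep_wrapped_methods``; the ``retry`` and
+        # ``timeout`` arguments passed below override those defaults on a
+        # per-call basis.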
+        rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', request.parent),
+            )),
+        )
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__iter__` convenience method.
+        response = pagers.SearchMigratableResourcesPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def batch_migrate_resources(self,
+            request: migration_service.BatchMigrateResourcesRequest = None,
+            *,
+            parent: str = None,
+            migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> operation.Operation:
+        r"""Batch migrates resources from ml.googleapis.com,
+        automl.googleapis.com, and datalabeling.googleapis.com
+        to AI Platform (Unified).
+
+        Args:
+            request (:class:`~.migration_service.BatchMigrateResourcesRequest`):
+                The request object. Request message for
+                ``MigrationService.BatchMigrateResources``.
+            parent (:class:`str`):
+                Required. The location the migrated resources will
+                live in. Format:
+                ``projects/{project}/locations/{location}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            migrate_resource_requests (:class:`Sequence[~.migration_service.MigrateResourceRequest]`):
+                Required. The request messages
+                specifying the resources to migrate.
+                They must be in the same location as the
+                destination. Up to 50 resources can be
+                migrated in one batch.
+                This corresponds to the ``migrate_resource_requests`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.operation.Operation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`~.migration_service.BatchMigrateResourcesResponse`:
+                Response message for
+                ``MigrationService.BatchMigrateResources``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent, migrate_resource_requests])
+        if request is not None and has_flattened_params:
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a migration_service.BatchMigrateResourcesRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, migration_service.BatchMigrateResourcesRequest):
+            request = migration_service.BatchMigrateResourcesRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
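+        # Note: ``migrate_resource_requests`` is a repeated proto field, so
+        # the flattened values are appended with ``extend()`` below rather
+        # than assigned directly.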
+
+        if parent is not None:
+            request.parent = parent
+
+        if migrate_resource_requests:
+            request.migrate_resource_requests.extend(migrate_resource_requests)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.batch_migrate_resources]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', request.parent),
+            )),
+        )
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Wrap the response in an operation future.
+        response = operation.from_gapic(
+            response,
+            self._transport.operations_client,
+            migration_service.BatchMigrateResourcesResponse,
+            metadata_type=migration_service.BatchMigrateResourcesOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            'google-cloud-aiplatform',
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = (
+    'MigrationServiceClient',
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py
new file mode 100644
index 0000000000..cc52903d15
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py
@@ -0,0 +1,143 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple
+
+from google.cloud.aiplatform_v1beta1.types import migratable_resource
+from google.cloud.aiplatform_v1beta1.types import migration_service
+
+
+class SearchMigratableResourcesPager:
+    """A pager for iterating through ``search_migratable_resources`` requests.
+
+    This class thinly wraps an initial
+    :class:`~.migration_service.SearchMigratableResourcesResponse` object, and
+    provides an ``__iter__`` method to iterate through its
+    ``migratable_resources`` field.
+
+    If there are more pages, the ``__iter__`` method will make additional
+    ``SearchMigratableResources`` requests and continue to iterate
+    through the ``migratable_resources`` field on the
+    corresponding responses.
+
+    All the usual :class:`~.migration_service.SearchMigratableResourcesResponse`
+    attributes are available on the pager. If multiple requests are made, only
+    the most recent response is retained, and thus used for attribute lookup.
+    """
+    def __init__(self,
+            method: Callable[..., migration_service.SearchMigratableResourcesResponse],
+            request: migration_service.SearchMigratableResourcesRequest,
+            response: migration_service.SearchMigratableResourcesResponse,
+            *,
+            metadata: Sequence[Tuple[str, str]] = ()):
+        """Instantiate the pager.
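+
+        The pager keeps a reference to the initial request and the most
+        recent response; iteration re-invokes ``method`` with the stored
+        request while ``next_page_token`` is set on the latest response.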
+ + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The initial request object. + response (:class:`~.migration_service.SearchMigratableResourcesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = migration_service.SearchMigratableResourcesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[migration_service.SearchMigratableResourcesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: + for page in self.pages: + yield from page.migratable_resources + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class SearchMigratableResourcesAsyncPager: + """A pager for iterating through ``search_migratable_resources`` requests. + + This class thinly wraps an initial + :class:`~.migration_service.SearchMigratableResourcesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``migratable_resources`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``SearchMigratableResources`` requests and continue to iterate + through the ``migratable_resources`` field on the + corresponding responses. + + All the usual :class:`~.migration_service.SearchMigratableResourcesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[migration_service.SearchMigratableResourcesResponse]], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The initial request object. + response (:class:`~.migration_service.SearchMigratableResourcesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
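+
+        Note that the pager re-sends this metadata with every additional
+        page request it makes.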
+ """ + self._method = method + self._request = migration_service.SearchMigratableResourcesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[migratable_resource.MigratableResource]: + async def async_generator(): + async for page in self.pages: + for response in page.migratable_resources: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py new file mode 100644 index 0000000000..e42711db2e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import MigrationServiceTransport +from .grpc import MigrationServiceGrpcTransport +from .grpc_asyncio import MigrationServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] +_transport_registry['grpc'] = MigrationServiceGrpcTransport +_transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport + + +__all__ = ( + 'MigrationServiceTransport', + 'MigrationServiceGrpcTransport', + 'MigrationServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py new file mode 100644 index 0000000000..e48c2471f6 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import migration_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            'google-cloud-aiplatform',
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class MigrationServiceTransport(abc.ABC):
+    """Abstract transport class for MigrationService."""
+
+    AUTH_SCOPES = (
+        'https://www.googleapis.com/auth/cloud-platform',
+    )
+
+    def __init__(
+            self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: typing.Optional[str] = None,
+            scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+            quota_project_id: typing.Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            **kwargs,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ':' not in host:
+            host += ':443'
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file,
+                scopes=scopes,
+                quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
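+        # Each RPC is wrapped exactly once so that all calls made through
+        # this transport share the same default timeout and user-agent
+        # (client info) settings.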
+ self._wrapped_methods = { + self.search_migratable_resources: gapic_v1.method.wrap_method( + self.search_migratable_resources, + default_timeout=None, + client_info=client_info, + ), + self.batch_migrate_resources: gapic_v1.method.wrap_method( + self.batch_migrate_resources, + default_timeout=None, + client_info=client_info, + ), + + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def search_migratable_resources(self) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse] + ]]: + raise NotImplementedError() + + @property + def batch_migrate_resources(self) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'MigrationServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py new file mode 100644 index 0000000000..bf0e91b721 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import migration_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import MigrationServiceTransport, DEFAULT_CLIENT_INFO + + +class MigrationServiceGrpcTransport(MigrationServiceTransport): + """gRPC backend transport for MigrationService. + + A service that migrates resources from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
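+
+    A minimal construction sketch (illustrative only; it assumes that
+    application default credentials are available in the environment)::
+
+        transport = MigrationServiceGrpcTransport(
+            host='aiplatform.googleapis.com',
+        )
+        client = MigrationServiceClient(transport=transport)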
+ """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. 
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
+        self._stubs = {}  # type: Dict[str, Callable]
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(cls,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: str = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            **kwargs) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+        if 'operations_client' not in self.__dict__:
+            self.__dict__['operations_client'] = operations_v1.OperationsClient(
+                self.grpc_channel
+            )
+
+        # Return the client from cache.
+        return self.__dict__['operations_client']
+
+    @property
+    def search_migratable_resources(self) -> Callable[
+            [migration_service.SearchMigratableResourcesRequest],
+            migration_service.SearchMigratableResourcesResponse]:
+        r"""Return a callable for the search migratable resources method over gRPC.
+ + Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Returns: + Callable[[~.SearchMigratableResourcesRequest], + ~.SearchMigratableResourcesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'search_migratable_resources' not in self._stubs: + self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', + request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, + response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, + ) + return self._stubs['search_migratable_resources'] + + @property + def batch_migrate_resources(self) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + operations.Operation]: + r"""Return a callable for the batch migrate resources method over gRPC. + + Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Returns: + Callable[[~.BatchMigrateResourcesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_migrate_resources' not in self._stubs: + self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', + request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['batch_migrate_resources'] + + +__all__ = ( + 'MigrationServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..3c12daf987 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import migration_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import MigrationServiceGrpcTransport
+
+
+class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport):
+    """gRPC AsyncIO backend transport for MigrationService.
+
+    A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(cls,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            **kwargs) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
+        )
+
+    def __init__(self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: aio.Channel = None,
+            api_mtls_endpoint: str = None,
+            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+            ssl_channel_credentials: grpc.ChannelCredentials = None,
+            quota_project_id=None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + + @property + def search_migratable_resources(self) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse]]: + r"""Return a callable for the search migratable resources method over gRPC. + + Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Returns: + Callable[[~.SearchMigratableResourcesRequest], + Awaitable[~.SearchMigratableResourcesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'search_migratable_resources' not in self._stubs: + self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', + request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, + response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, + ) + return self._stubs['search_migratable_resources'] + + @property + def batch_migrate_resources(self) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the batch migrate resources method over gRPC. + + Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Returns: + Callable[[~.BatchMigrateResourcesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
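+        # Note: the response is deserialized with
+        # ``operations.Operation.FromString`` because this RPC returns a
+        # google.longrunning.Operation rather than a service-specific
+        # message type.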
+ if 'batch_migrate_resources' not in self._stubs: + self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', + request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['batch_migrate_resources'] + + +__all__ = ( + 'MigrationServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py index b0d80fcc98..3ee8fc6e9e 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py @@ -16,5 +16,9 @@ # from .client import ModelServiceClient +from .async_client import ModelServiceAsyncClient -__all__ = ("ModelServiceClient",) +__all__ = ( + 'ModelServiceClient', + 'ModelServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py new file mode 100644 index 0000000000..1f35b4a15f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -0,0 +1,1044 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.model_service import pagers +from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform_v1beta1.types import model_evaluation +from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice +from google.cloud.aiplatform_v1beta1.types import model_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .client import ModelServiceClient + + +class ModelServiceAsyncClient: + """A service for managing AI Platform's machine learning Models.""" + + _client: ModelServiceClient + + DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(ModelServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(ModelServiceClient.parse_endpoint_path) + model_path = staticmethod(ModelServiceClient.model_path) + parse_model_path = staticmethod(ModelServiceClient.parse_model_path) + model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) + parse_model_evaluation_path = staticmethod(ModelServiceClient.parse_model_evaluation_path) + model_evaluation_slice_path = staticmethod(ModelServiceClient.model_evaluation_slice_path) + parse_model_evaluation_slice_path = staticmethod(ModelServiceClient.parse_model_evaluation_slice_path) + training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) + parse_training_pipeline_path = staticmethod(ModelServiceClient.parse_training_pipeline_path) + + common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(ModelServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(ModelServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(ModelServiceClient.common_project_path) + parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) + + common_location_path = 
staticmethod(ModelServiceClient.common_location_path) + parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) + + from_service_account_file = ModelServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> ModelServiceTransport: + """Return the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the model service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.ModelServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = ModelServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def upload_model(self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Uploads a Model artifact into AI Platform. + + Args: + request (:class:`~.model_service.UploadModelRequest`): + The request object. Request message for + ``ModelService.UploadModel``. + parent (:class:`str`): + Required. The resource name of the Location into which + to upload the Model. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + model (:class:`~.gca_model.Model`): + Required. The Model to create. 
+ This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.model_service.UploadModelResponse`: Response + message of + ``ModelService.UploadModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, model]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.UploadModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model is not None: + request.model = model + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.upload_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + model_service.UploadModelResponse, + metadata_type=model_service.UploadModelOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_model(self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets a Model. + + Args: + request (:class:`~.model_service.GetModelRequest`): + The request object. Request message for + ``ModelService.GetModel``. + name (:class:`str`): + Required. The name of the Model resource. Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.model.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
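+        # The flattened `name` argument and a full `request` object are
+        # mutually exclusive, as the check below enforces; a sketch with an
+        # illustrative resource name:
+        #
+        #   model = await client.get_model(
+        #       name="projects/my-project/locations/us-central1/models/123")
+        #
+        #   # Passing both raises ValueError:
+        #   await client.get_model(
+        #       request=model_service.GetModelRequest(name=...), name=...)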
+ if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_models(self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: + r"""Lists Models in a Location. + + Args: + request (:class:`~.model_service.ListModelsRequest`): + The request object. Request message for + ``ModelService.ListModels``. + parent (:class:`str`): + Required. The resource name of the Location to list the + Models from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListModelsAsyncPager: + Response message for + ``ModelService.ListModels`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.ListModelsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_models, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. 
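+        # A usage sketch of the pager (the parent value is illustrative);
+        # page tokens are resolved transparently during iteration:
+        #
+        #   pager = await client.list_models(
+        #       parent="projects/my-project/locations/us-central1")
+        #   async for model in pager:
+        #       print(model.name)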
+ return response + + async def update_model(self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: + r"""Updates a Model. + + Args: + request (:class:`~.model_service.UpdateModelRequest`): + The request object. Request message for + ``ModelService.UpdateModel``. + model (:class:`~.gca_model.Model`): + Required. The Model which replaces + the resource on the server. + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + + [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_model.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([model, update_mask]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.UpdateModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model is not None: + request.model = model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model.name', request.model.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_model(self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Args: + request (:class:`~.model_service.DeleteModelRequest`): + The request object. Request message for + ``ModelService.DeleteModel``. + name (:class:`str`): + Required. The name of the Model resource to be deleted. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.DeleteModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def export_model(self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. + + Args: + request (:class:`~.model_service.ExportModelRequest`): + The request object. Request message for + ``ModelService.ExportModel``. + name (:class:`str`): + Required. The resource name of the Model to export. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + output_config (:class:`~.model_service.ExportModelRequest.OutputConfig`): + Required. The desired output location + and configuration. + This corresponds to the ``output_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`~.model_service.ExportModelResponse`: Response
+                message of
+                ``ModelService.ExportModel``
+                operation.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([name, output_config]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = model_service.ExportModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+        if output_config is not None:
+            request.output_config = output_config
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.export_model,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('name', request.name),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            model_service.ExportModelResponse,
+            metadata_type=model_service.ExportModelOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def get_model_evaluation(self,
+            request: model_service.GetModelEvaluationRequest = None,
+            *,
+            name: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> model_evaluation.ModelEvaluation:
+        r"""Gets a ModelEvaluation.
+
+        Args:
+            request (:class:`~.model_service.GetModelEvaluationRequest`):
+                The request object. Request message for
+                ``ModelService.GetModelEvaluation``.
+            name (:class:`str`):
+                Required. The name of the ModelEvaluation resource.
+                Format:
+
+                ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.model_evaluation.ModelEvaluation:
+                A collection of metrics calculated by
+                comparing the Model's predictions on all of
+                the test data against annotations from
+                the test data.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+ if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.GetModelEvaluationRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model_evaluation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_model_evaluations(self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: + r"""Lists ModelEvaluations in a Model. + + Args: + request (:class:`~.model_service.ListModelEvaluationsRequest`): + The request object. Request message for + ``ModelService.ListModelEvaluations``. + parent (:class:`str`): + Required. The resource name of the Model to list the + ModelEvaluations from. Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListModelEvaluationsAsyncPager: + Response message for + ``ModelService.ListModelEvaluations``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.ListModelEvaluationsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_evaluations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. 
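+        # Besides per-item iteration, the pager also exposes whole pages; a
+        # sketch assuming `model_name` holds a full Model resource name:
+        #
+        #   pager = await client.list_model_evaluations(parent=model_name)
+        #   async for page in pager.pages:
+        #       print(len(page.model_evaluations))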
+        response = pagers.ListModelEvaluationsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def get_model_evaluation_slice(self,
+            request: model_service.GetModelEvaluationSliceRequest = None,
+            *,
+            name: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> model_evaluation_slice.ModelEvaluationSlice:
+        r"""Gets a ModelEvaluationSlice.
+
+        Args:
+            request (:class:`~.model_service.GetModelEvaluationSliceRequest`):
+                The request object. Request message for
+                ``ModelService.GetModelEvaluationSlice``.
+            name (:class:`str`):
+                Required. The name of the ModelEvaluationSlice resource.
+                Format:
+
+                ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.model_evaluation_slice.ModelEvaluationSlice:
+                A collection of metrics calculated by
+                comparing the Model's predictions on a slice
+                of the test data against ground truth
+                annotations.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        if request is not None and any([name]):
+            raise ValueError('If the `request` argument is set, then none of '
+                             'the individual field arguments should be set.')
+
+        request = model_service.GetModelEvaluationSliceRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.get_model_evaluation_slice,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('name', request.name),
+            )),
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def list_model_evaluation_slices(self,
+            request: model_service.ListModelEvaluationSlicesRequest = None,
+            *,
+            parent: str = None,
+            retry: retries.Retry = gapic_v1.method.DEFAULT,
+            timeout: float = None,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> pagers.ListModelEvaluationSlicesAsyncPager:
+        r"""Lists ModelEvaluationSlices in a ModelEvaluation.
+
+        Args:
+            request (:class:`~.model_service.ListModelEvaluationSlicesRequest`):
+                The request object. Request message for
+                ``ModelService.ListModelEvaluationSlices``.
+            parent (:class:`str`):
+                Required. The resource name of the ModelEvaluation to
+                list the ModelEvaluationSlices from. Format:
+
+                ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListModelEvaluationSlicesAsyncPager: + Response message for + ``ModelService.ListModelEvaluationSlices``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = model_service.ListModelEvaluationSlicesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_evaluation_slices, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelEvaluationSlicesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. 
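+        # The LRO-returning methods above (upload_model, delete_model,
+        # export_model) hand back an AsyncOperation whose result is awaited;
+        # a sketch with an illustrative model name:
+        #
+        #   operation = await client.delete_model(
+        #       name="projects/my-project/locations/us-central1/models/123")
+        #   await operation.result()  # resolves once the server finishes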
+ return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'ModelServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index dab285be4c..cade034da4 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref from google.cloud.aiplatform_v1beta1.types import explanation @@ -41,8 +48,9 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import ModelServiceTransport +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import ModelServiceGrpcTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport class ModelServiceClientMeta(type): @@ -52,11 +60,13 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry['grpc'] = ModelServiceGrpcTransport + _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[ModelServiceTransport]: """Return an appropriate transport class. 
        Args:
@@ -78,8 +88,38 @@ def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]:
 class ModelServiceClient(metaclass=ModelServiceClientMeta):
     """A service for managing AI Platform's machine learning Models."""
 
-    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
-        api_endpoint="aiplatform.googleapis.com"
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -96,26 +136,138 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(filename)
-        kwargs["credentials"] = credentials
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs['credentials'] = credentials
         return cls(*args, **kwargs)
 
     from_service_account_json = from_service_account_file
 
+    @property
+    def transport(self) -> ModelServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            ModelServiceTransport: The transport used by the client instance.
+ """ + return self._transport + + @staticmethod + def endpoint_path(project: str,location: str,endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def parse_endpoint_path(path: str) -> Dict[str,str]: + """Parse a endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = None, - client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, - ) -> None: + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parse a model path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: + """Return a fully-qualified model_evaluation string.""" + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + + @staticmethod + def parse_model_evaluation_path(path: str) -> Dict[str,str]: + """Parse a model_evaluation path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: + """Return a fully-qualified model_evaluation_slice string.""" + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + + @staticmethod + def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: + """Parse a model_evaluation_slice path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + """Return a fully-qualified training_pipeline string.""" + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + + @staticmethod + def parse_training_pipeline_path(path: str) -> Dict[str,str]: + """Parse a training_pipeline path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return 
"billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -127,38 +279,107 @@ def __init__( transport (Union[str, ~.ModelServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. 
If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, ModelServiceTransport): - if credentials: + # transport is a ModelServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: raise ValueError( "When providing a transport instance, " - "provide its credentials directly." + "provide its scopes directly." 
) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) - def upload_model( - self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def upload_model(self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -198,32 +419,45 @@ def upload_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, model]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.UploadModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if model is not None: - request.model = model + has_flattened_params = any([parent, model]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.UploadModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.UploadModelRequest): + request = model_service.UploadModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model is not None: + request.model = model # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.upload_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.upload_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -236,15 +470,14 @@ def upload_model( # Done; return the response. 
return response - def get_model( - self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model(self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -271,47 +504,55 @@ def get_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.GetModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_model, default_timeout=None, client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_model] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_models( - self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models(self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -344,54 +585,65 @@ def list_models( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.ListModelsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_models, default_timeout=None, client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_models] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def update_model( - self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model(self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -426,45 +678,57 @@ def update_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([model, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.UpdateModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if model is not None: - request.model = model - if update_mask is not None: - request.update_mask = update_mask + has_flattened_params = any([model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.UpdateModelRequest. 
+ # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.UpdateModelRequest): + request = model_service.UpdateModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model is not None: + request.model = model + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('model.name', request.model.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def delete_model( - self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model(self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -511,30 +775,43 @@ def delete_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.DeleteModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.DeleteModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.DeleteModelRequest): + request = model_service.DeleteModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. 
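
An aside on update_model: the regenerated method now attaches a routing header keyed on ``model.name``, and only the fields named in the mask are written. A sketch of a flattened update, changing just the display name (all identifiers hypothetical):

    from google.cloud import aiplatform_v1beta1
    from google.protobuf import field_mask_pb2

    client = aiplatform_v1beta1.ModelServiceClient()
    model = aiplatform_v1beta1.Model(
        name="projects/my-project/locations/us-central1/models/123",
        display_name="renamed-model",
    )
    mask = field_mask_pb2.FieldMask(paths=["display_name"])

    updated = client.update_model(model=model, update_mask=mask)
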
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -547,16 +824,15 @@ def delete_model( # Done; return the response. return response - def export_model( - self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_model(self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -600,32 +876,45 @@ def export_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, output_config]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.ExportModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - if output_config is not None: - request.output_config = output_config + has_flattened_params = any([name, output_config]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ExportModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ExportModelRequest): + request = model_service.ExportModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if output_config is not None: + request.output_config = output_config # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.export_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.export_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -638,15 +927,14 @@ def export_model( # Done; return the response. 
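
Both delete_model and export_model return long-running operation futures (the ``from_gapic`` wrap above) rather than raw responses; callers typically block on ``result()``. A sketch, assuming a hypothetical model and an arbitrary timeout:

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.ModelServiceClient()
    operation = client.delete_model(
        name="projects/my-project/locations/us-central1/models/123"
    )

    # Blocks until the server-side operation completes; raises on failure.
    operation.result(timeout=300)
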
return response - def get_model_evaluation( - self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation(self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -679,49 +967,55 @@ def get_model_evaluation( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.GetModelEvaluationRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelEvaluationRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelEvaluationRequest): + request = model_service.GetModelEvaluationRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_model_evaluation, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_model_evaluations( - self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations(self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -754,55 +1048,64 @@ def list_model_evaluations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
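
The recurring routing-header block deserves a note: if I read the api_core helper right, ``to_grpc_metadata`` folds the named request fields into a single ``x-goog-request-params`` metadata entry that the backend uses to route the call. A standalone sketch:

    from google.api_core import gapic_v1

    header = gapic_v1.routing_header.to_grpc_metadata(
        (("name", "projects/my-project/locations/us-central1/models/123"),)
    )
    # Roughly ('x-goog-request-params', 'name=projects%2Fmy-project%2F...')
    print(header)
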
- if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.ListModelEvaluationsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelEvaluationsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelEvaluationsRequest): + request = model_service.ListModelEvaluationsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_model_evaluations, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_model_evaluations] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice( - self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice(self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -835,49 +1138,55 @@ def get_model_evaluation_slice( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
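
Stepping back to the pagers: they now forward the original call's metadata on every follow-up page request (the new ``metadata=`` argument above). Iteration itself is unchanged; a sketch against a hypothetical model:

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.ModelServiceClient()
    parent = "projects/my-project/locations/us-central1/models/123"  # hypothetical

    # Item-level iteration; pages are fetched lazily as needed.
    for evaluation in client.list_model_evaluations(parent=parent):
        print(evaluation.name)

    # Page-level iteration, when per-page handling is needed.
    for page in client.list_model_evaluations(parent=parent).pages:
        print(len(page.model_evaluations))
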
- ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.GetModelEvaluationSliceRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelEvaluationSliceRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelEvaluationSliceRequest): + request = model_service.GetModelEvaluationSliceRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_model_evaluation_slice, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_model_evaluation_slices( - self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices(self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -911,55 +1220,72 @@ def list_model_evaluation_slices( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = model_service.ListModelEvaluationSlicesRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelEvaluationSlicesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): + request = model_service.ListModelEvaluationSlicesRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
+ # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_model_evaluation_slices, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("ModelServiceClient",) +__all__ = ( + 'ModelServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 4169c27b85..716d790932 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation @@ -40,15 +40,12 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [model_service.ListModelsRequest], model_service.ListModelsResponse - ], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - ): + def __init__(self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -58,10 +55,13 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -71,7 +71,7 @@ def pages(self) -> Iterable[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[model.Model]: @@ -79,7 +79,70 @@ def __iter__(self) -> Iterable[model.Model]: yield from page.models def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelsAsyncPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`~.model_service.ListModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`~.model_service.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.model_service.ListModelsRequest`): + The initial request object. + response (:class:`~.model_service.ListModelsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model.Model]: + async def async_generator(): + async for page in self.pages: + for response in page.models: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationsPager: @@ -99,16 +162,12 @@ class ListModelEvaluationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - - def __init__( - self, - method: Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, - ], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - ): + def __init__(self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -118,10 +177,13 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelEvaluationsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -131,7 +193,7 @@ def pages(self) -> Iterable[model_service.ListModelEvaluationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: @@ -139,7 +201,70 @@ def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: yield from page.model_evaluations def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelEvaluationsAsyncPager: + """A pager for iterating through ``list_model_evaluations`` requests. + + This class thinly wraps an initial + :class:`~.model_service.ListModelEvaluationsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_evaluations`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelEvaluations`` requests and continue to iterate + through the ``model_evaluations`` field on the + corresponding responses. + + All the usual :class:`~.model_service.ListModelEvaluationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.model_service.ListModelEvaluationsRequest`): + The initial request object. + response (:class:`~.model_service.ListModelEvaluationsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = model_service.ListModelEvaluationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_evaluation.ModelEvaluation]: + async def async_generator(): + async for page in self.pages: + for response in page.model_evaluations: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesPager: @@ -159,16 +284,12 @@ class ListModelEvaluationSlicesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, - ], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - ): + def __init__(self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -178,10 +299,13 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -191,7 +315,7 @@ def pages(self) -> Iterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: @@ -199,4 +323,67 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: yield from page.model_evaluation_slices def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelEvaluationSlicesAsyncPager: + """A pager for iterating through ``list_model_evaluation_slices`` requests. + + This class thinly wraps an initial + :class:`~.model_service.ListModelEvaluationSlicesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_evaluation_slices`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelEvaluationSlices`` requests and continue to iterate + through the ``model_evaluation_slices`` field on the + corresponding responses. 
+ + All the usual :class:`~.model_service.ListModelEvaluationSlicesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationSlicesResponse]], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + The initial request object. + response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelEvaluationSlicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_evaluation_slice.ModelEvaluationSlice]: + async def async_generator(): + async for page in self.pages: + for response in page.model_evaluation_slices: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index 7bbcc75582..89bd6faee0 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import ModelServiceTransport from .grpc import ModelServiceGrpcTransport +from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport # Compile a registry of transports. 
 _transport_registry = OrderedDict()  # type: Dict[str, Type[ModelServiceTransport]]
-_transport_registry["grpc"] = ModelServiceGrpcTransport
+_transport_registry['grpc'] = ModelServiceGrpcTransport
+_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport

 __all__ = (
-    "ModelServiceTransport",
-    "ModelServiceGrpcTransport",
+    'ModelServiceTransport',
+    'ModelServiceGrpcTransport',
+    'ModelServiceGrpcAsyncIOTransport',
 )
diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py
index 53f94ea393..d5f10a9943 100644
--- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py
@@ -17,8 +17,12 @@
 import abc
 import typing
+import pkg_resources

-from google import auth
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
 from google.api_core import operations_v1  # type: ignore
 from google.auth import credentials  # type: ignore
@@ -30,17 +34,32 @@
 from google.longrunning import operations_pb2 as operations  # type: ignore

-class ModelServiceTransport(metaclass=abc.ABCMeta):
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            'google-cloud-aiplatform',
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+class ModelServiceTransport(abc.ABC):
     """Abstract transport class for ModelService."""

-    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+    AUTH_SCOPES = (
+        'https://www.googleapis.com/auth/cloud-platform',
+    )

     def __init__(
-        self,
-        *,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-    ) -> None:
+            self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: typing.Optional[str] = None,
+            scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+            quota_project_id: typing.Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            **kwargs,
+            ) -> None:
         """Instantiate the transport.
         Args:
@@ -50,97 +69,196 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
-        if ":" not in host:
-            host += ":443"
+        if ':' not in host:
+            host += ':443'
         self._host = host

         # If no credentials are provided, then determine the appropriate
         # defaults.
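
Two things worth illustrating at this point: the transport registry above is what backs the client's ``transport=`` argument, and the new ``credentials_file`` path is mutually exclusive with ``credentials`` (passing both raises DuplicateCredentialArgs, per the check just below). A sketch of both, with a hypothetical key file:

    from google import auth
    from google.cloud.aiplatform_v1beta1.services.model_service import transports

    # Same lookup the client factory performs for transport="grpc".
    transport_cls = transports.ModelServiceGrpcTransport

    # What the base transport does when credentials_file is supplied.
    credentials, _ = auth.load_credentials_from_file(
        "service-account.json",  # hypothetical
        scopes=("https://www.googleapis.com/auth/cloud-platform",),
    )
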
- if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.upload_model: gapic_v1.method.wrap_method( + self.upload_model, + default_timeout=None, + client_info=client_info, + ), + self.get_model: gapic_v1.method.wrap_method( + self.get_model, + default_timeout=None, + client_info=client_info, + ), + self.list_models: gapic_v1.method.wrap_method( + self.list_models, + default_timeout=None, + client_info=client_info, + ), + self.update_model: gapic_v1.method.wrap_method( + self.update_model, + default_timeout=None, + client_info=client_info, + ), + self.delete_model: gapic_v1.method.wrap_method( + self.delete_model, + default_timeout=None, + client_info=client_info, + ), + self.export_model: gapic_v1.method.wrap_method( + self.export_model, + default_timeout=None, + client_info=client_info, + ), + self.get_model_evaluation: gapic_v1.method.wrap_method( + self.get_model_evaluation, + default_timeout=None, + client_info=client_info, + ), + self.list_model_evaluations: gapic_v1.method.wrap_method( + self.list_model_evaluations, + default_timeout=None, + client_info=client_info, + ), + self.get_model_evaluation_slice: gapic_v1.method.wrap_method( + self.get_model_evaluation_slice, + default_timeout=None, + client_info=client_info, + ), + self.list_model_evaluation_slices: gapic_v1.method.wrap_method( + self.list_model_evaluation_slices, + default_timeout=None, + client_info=client_info, + ), + + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property - def upload_model( - self, - ) -> typing.Callable[[model_service.UploadModelRequest], operations.Operation]: - raise NotImplementedError + def upload_model(self) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def get_model( - self, - ) -> typing.Callable[[model_service.GetModelRequest], model.Model]: - raise NotImplementedError + def get_model(self) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[ + model.Model, + typing.Awaitable[model.Model] + ]]: + raise NotImplementedError() @property - def list_models( - self, - ) -> typing.Callable[ - [model_service.ListModelsRequest], model_service.ListModelsResponse - ]: - raise NotImplementedError + def list_models(self) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse] + ]]: + raise NotImplementedError() @property - def update_model( - self, - ) -> typing.Callable[[model_service.UpdateModelRequest], gca_model.Model]: - raise NotImplementedError + def 
update_model(self) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[ + gca_model.Model, + typing.Awaitable[gca_model.Model] + ]]: + raise NotImplementedError() @property - def delete_model( - self, - ) -> typing.Callable[[model_service.DeleteModelRequest], operations.Operation]: - raise NotImplementedError + def delete_model(self) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def export_model( - self, - ) -> typing.Callable[[model_service.ExportModelRequest], operations.Operation]: - raise NotImplementedError + def export_model(self) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def get_model_evaluation( - self, - ) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation - ]: - raise NotImplementedError + def get_model_evaluation(self) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation] + ]]: + raise NotImplementedError() @property - def list_model_evaluations( - self, - ) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, - ]: - raise NotImplementedError + def list_model_evaluations(self) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse] + ]]: + raise NotImplementedError() @property - def get_model_evaluation_slice( - self, - ) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice, - ]: - raise NotImplementedError + def get_model_evaluation_slice(self) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] + ]]: + raise NotImplementedError() @property - def list_model_evaluation_slices( - self, - ) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, - ]: - raise NotImplementedError + def list_model_evaluation_slices(self) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] + ]]: + raise NotImplementedError() -__all__ = ("ModelServiceTransport",) +__all__ = ( + 'ModelServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index f83c41e879..255d478e9d 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. 
 #

-from typing import Callable, Dict
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple

-from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import grpc_helpers  # type: ignore
 from google.api_core import operations_v1  # type: ignore
-from google.auth import credentials  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore

 import grpc  # type: ignore
@@ -30,7 +34,7 @@
 from google.cloud.aiplatform_v1beta1.types import model_service
 from google.longrunning import operations_pb2 as operations  # type: ignore

-from .base import ModelServiceTransport
+from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO

 class ModelServiceGrpcTransport(ModelServiceTransport):
@@ -45,14 +49,20 @@ class ModelServiceGrpcTransport(ModelServiceTransport):
     It sends protocol buffers over the wire using gRPC (which is built on
     top of HTTP/2); the ``grpcio`` package must be installed.
     """
-
-    def __init__(
-        self,
-        *,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        channel: grpc.Channel = None
-    ) -> None:
+    _stubs: Dict[str, Callable]
+
+    def __init__(self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: str = None,
+            scopes: Sequence[str] = None,
+            channel: grpc.Channel = None,
+            api_mtls_endpoint: str = None,
+            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+            ssl_channel_credentials: grpc.ChannelCredentials = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
         """Instantiate the transport.
         Args:
@@ -63,29 +73,107 @@ def __init__(
             are specified, the client will attempt to ascertain the
             credentials from the environment.
             This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False
-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]

-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )

     @classmethod
-    def create_channel(
-        cls,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        **kwargs
-    ) -> grpc.Channel:
+    def create_channel(cls,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: str = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            **kwargs) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
             address (Optional[str]): The host for the channel to use.
@@ -94,30 +182,37 @@ def create_channel(
             credentials identify this application to the service. If none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
Returns: grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property @@ -128,18 +223,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self.__dict__['operations_client'] @property - def upload_model( - self, - ) -> Callable[[model_service.UploadModelRequest], operations.Operation]: + def upload_model(self) -> Callable[ + [model_service.UploadModelRequest], + operations.Operation]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -154,16 +249,18 @@ def upload_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "upload_model" not in self._stubs: - self._stubs["upload_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", + if 'upload_model' not in self._stubs: + self._stubs['upload_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["upload_model"] + return self._stubs['upload_model'] @property - def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + model.Model]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -178,18 +275,18 @@ def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
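
The per-stub caching pattern above repeats for every method that follows: the callable is created against the channel once, memoized in ``self._stubs``, and reused on later property accesses. A stripped-down sketch of the same idea (the class and names are illustrative only):

    import grpc

    class _StubCache:
        def __init__(self, channel: grpc.Channel):
            self._channel = channel
            self._stubs = {}

        @property
        def get_model(self):
            # Created lazily on first access, then served from the cache.
            if "get_model" not in self._stubs:
                self._stubs["get_model"] = self._channel.unary_unary(
                    "/google.cloud.aiplatform.v1beta1.ModelService/GetModel"
                )
            return self._stubs["get_model"]
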
- if "get_model" not in self._stubs: - self._stubs["get_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs["get_model"] + return self._stubs['get_model'] @property - def list_models( - self, - ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -204,18 +301,18 @@ def list_models( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_models" not in self._stubs: - self._stubs["list_models"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs["list_models"] + return self._stubs['list_models'] @property - def update_model( - self, - ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: + def update_model(self) -> Callable[ + [model_service.UpdateModelRequest], + gca_model.Model]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -230,18 +327,18 @@ def update_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_model" not in self._stubs: - self._stubs["update_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", + if 'update_model' not in self._stubs: + self._stubs['update_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs["update_model"] + return self._stubs['update_model'] @property - def delete_model( - self, - ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: + def delete_model(self) -> Callable[ + [model_service.DeleteModelRequest], + operations.Operation]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -258,18 +355,18 @@ def delete_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "delete_model" not in self._stubs: - self._stubs["delete_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", + if 'delete_model' not in self._stubs: + self._stubs['delete_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_model"] + return self._stubs['delete_model'] @property - def export_model( - self, - ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: + def export_model(self) -> Callable[ + [model_service.ExportModelRequest], + operations.Operation]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -287,20 +384,18 @@ def export_model( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "export_model" not in self._stubs: - self._stubs["export_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", + if 'export_model' not in self._stubs: + self._stubs['export_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["export_model"] + return self._stubs['export_model'] @property - def get_model_evaluation( - self, - ) -> Callable[ - [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation - ]: + def get_model_evaluation(self) -> Callable[ + [model_service.GetModelEvaluationRequest], + model_evaluation.ModelEvaluation]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -315,21 +410,18 @@ def get_model_evaluation( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation" not in self._stubs: - self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", + if 'get_model_evaluation' not in self._stubs: + self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs["get_model_evaluation"] + return self._stubs['get_model_evaluation'] @property - def list_model_evaluations( - self, - ) -> Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, - ]: + def list_model_evaluations(self) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -344,21 +436,18 @@ def list_model_evaluations( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "list_model_evaluations" not in self._stubs: - self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", + if 'list_model_evaluations' not in self._stubs: + self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs["list_model_evaluations"] + return self._stubs['list_model_evaluations'] @property - def get_model_evaluation_slice( - self, - ) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice, - ]: + def get_model_evaluation_slice(self) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -373,21 +462,18 @@ def get_model_evaluation_slice( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_model_evaluation_slice" not in self._stubs: - self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", + if 'get_model_evaluation_slice' not in self._stubs: + self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs["get_model_evaluation_slice"] + return self._stubs['get_model_evaluation_slice'] @property - def list_model_evaluation_slices( - self, - ) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, - ]: + def list_model_evaluation_slices(self) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -402,13 +488,15 @@ def list_model_evaluation_slices( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if "list_model_evaluation_slices" not in self._stubs: - self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", + if 'list_model_evaluation_slices' not in self._stubs: + self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs["list_model_evaluation_slices"] + return self._stubs['list_model_evaluation_slices'] -__all__ = ("ModelServiceGrpcTransport",) +__all__ = ( + 'ModelServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..850a476d8f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -0,0 +1,507 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform_v1beta1.types import model_evaluation +from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice +from google.cloud.aiplatform_v1beta1.types import model_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import ModelServiceGrpcTransport + + +class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): + """gRPC AsyncIO backend transport for ModelService. + + A service for managing AI Platform's machine learning Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+ """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. 
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. 
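+ # (Editor's note, illustrative) Caching in ``self.__dict__`` means the
+ # first access pays the construction cost and every later access is a
+ # plain dict lookup returning the same OperationsAsyncClient.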
+ return self.__dict__['operations_client'] + + @property + def upload_model(self) -> Callable[ + [model_service.UploadModelRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the upload model method over gRPC. + + Uploads a Model artifact into AI Platform. + + Returns: + Callable[[~.UploadModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'upload_model' not in self._stubs: + self._stubs['upload_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', + request_serializer=model_service.UploadModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['upload_model'] + + @property + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + Awaitable[model.Model]]: + r"""Return a callable for the get model method over gRPC. + + Gets a Model. + + Returns: + Callable[[~.GetModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs['get_model'] + + @property + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + Awaitable[model_service.ListModelsResponse]]: + r"""Return a callable for the list models method over gRPC. + + Lists Models in a Location. + + Returns: + Callable[[~.ListModelsRequest], + Awaitable[~.ListModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs['list_models'] + + @property + def update_model(self) -> Callable[ + [model_service.UpdateModelRequest], + Awaitable[gca_model.Model]]: + r"""Return a callable for the update model method over gRPC. + + Updates a Model. + + Returns: + Callable[[~.UpdateModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
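+ # (Editor's note, illustrative) The stub is built once and cached in
+ # ``self._stubs``; repeated property accesses return the same callable
+ # instead of re-registering the method path on the channel.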
+ if 'update_model' not in self._stubs: + self._stubs['update_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', + request_serializer=model_service.UpdateModelRequest.serialize, + response_deserializer=gca_model.Model.deserialize, + ) + return self._stubs['update_model'] + + @property + def delete_model(self) -> Callable[ + [model_service.DeleteModelRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete model method over gRPC. + + Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Returns: + Callable[[~.DeleteModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_model' not in self._stubs: + self._stubs['delete_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', + request_serializer=model_service.DeleteModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_model'] + + @property + def export_model(self) -> Callable[ + [model_service.ExportModelRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the export model method over gRPC. + + Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. + + Returns: + Callable[[~.ExportModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'export_model' not in self._stubs: + self._stubs['export_model'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', + request_serializer=model_service.ExportModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['export_model'] + + @property + def get_model_evaluation(self) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation]]: + r"""Return a callable for the get model evaluation method over gRPC. + + Gets a ModelEvaluation. + + Returns: + Callable[[~.GetModelEvaluationRequest], + Awaitable[~.ModelEvaluation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
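+ # (Editor's illustration, not generated code) At this level the
+ # callable is awaitable, e.g.:
+ #
+ #     evaluation = await transport.get_model_evaluation(request)
+ #
+ # Most code should go through the async client instead, which layers
+ # retry, timeout, and metadata handling on top of this raw stub.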
+ if 'get_model_evaluation' not in self._stubs: + self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', + request_serializer=model_service.GetModelEvaluationRequest.serialize, + response_deserializer=model_evaluation.ModelEvaluation.deserialize, + ) + return self._stubs['get_model_evaluation'] + + @property + def list_model_evaluations(self) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse]]: + r"""Return a callable for the list model evaluations method over gRPC. + + Lists ModelEvaluations in a Model. + + Returns: + Callable[[~.ListModelEvaluationsRequest], + Awaitable[~.ListModelEvaluationsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_model_evaluations' not in self._stubs: + self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', + request_serializer=model_service.ListModelEvaluationsRequest.serialize, + response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, + ) + return self._stubs['list_model_evaluations'] + + @property + def get_model_evaluation_slice(self) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice]]: + r"""Return a callable for the get model evaluation slice method over gRPC. + + Gets a ModelEvaluationSlice. + + Returns: + Callable[[~.GetModelEvaluationSliceRequest], + Awaitable[~.ModelEvaluationSlice]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_model_evaluation_slice' not in self._stubs: + self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', + request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, + response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, + ) + return self._stubs['get_model_evaluation_slice'] + + @property + def list_model_evaluation_slices(self) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse]]: + r"""Return a callable for the list model evaluation slices method over gRPC. + + Lists ModelEvaluationSlices in a ModelEvaluation. + + Returns: + Callable[[~.ListModelEvaluationSlicesRequest], + Awaitable[~.ListModelEvaluationSlicesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
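+ # (Editor's note, illustrative) The transport returns one raw
+ # ListModelEvaluationSlicesResponse per call; automatic page
+ # resolution lives in the client-layer pagers, not here.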
+ if 'list_model_evaluation_slices' not in self._stubs: + self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', + request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, + response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, + ) + return self._stubs['list_model_evaluation_slices'] + + +__all__ = ( + 'ModelServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py index 29ba95e15f..f7f4d9b9ac 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py @@ -16,5 +16,9 @@ # from .client import PipelineServiceClient +from .async_client import PipelineServiceAsyncClient -__all__ = ("PipelineServiceClient",) +__all__ = ( + 'PipelineServiceClient', + 'PipelineServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py new file mode 100644 index 0000000000..6035fc4277 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -0,0 +1,598 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_service +from google.cloud.aiplatform_v1beta1.types import pipeline_state +from google.cloud.aiplatform_v1beta1.types import training_pipeline +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + +from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport +from .client import PipelineServiceClient + + +class PipelineServiceAsyncClient: + """A service for creating and managing AI Platform's pipelines.""" + + _client: PipelineServiceClient + + DEFAULT_ENDPOINT = PipelineServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PipelineServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(PipelineServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(PipelineServiceClient.parse_endpoint_path) + model_path = staticmethod(PipelineServiceClient.model_path) + parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) + training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) + parse_training_pipeline_path = staticmethod(PipelineServiceClient.parse_training_pipeline_path) + + common_billing_account_path = staticmethod(PipelineServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PipelineServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(PipelineServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(PipelineServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PipelineServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(PipelineServiceClient.common_project_path) + parse_common_project_path = staticmethod(PipelineServiceClient.parse_common_project_path) + + common_location_path = staticmethod(PipelineServiceClient.common_location_path) + parse_common_location_path = staticmethod(PipelineServiceClient.parse_common_location_path) + + from_service_account_file = PipelineServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PipelineServiceTransport: + """Return the transport used by the client instance. 
+
+ Returns:
+ PipelineServiceTransport: The transport used by the client instance.
+ """
+ return self._client.transport
+
+ get_transport_class = functools.partial(type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient))
+
+ def __init__(self, *,
+ credentials: credentials.Credentials = None,
+ transport: Union[str, PipelineServiceTransport] = 'grpc_asyncio',
+ client_options: ClientOptions = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiate the pipeline service client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, ~.PipelineServiceTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (ClientOptions): Custom options for the client. It
+ won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+
+ Raises:
+ google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+
+ self._client = PipelineServiceClient(
+ credentials=credentials,
+ transport=transport,
+ client_options=client_options,
+ client_info=client_info,
+
+ )
+
+ async def create_training_pipeline(self,
+ request: pipeline_service.CreateTrainingPipelineRequest = None,
+ *,
+ parent: str = None,
+ training_pipeline: gca_training_pipeline.TrainingPipeline = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gca_training_pipeline.TrainingPipeline:
+ r"""Creates a TrainingPipeline. An attempt to run the
+ newly created TrainingPipeline is made right away.
+
+ Args:
+ request (:class:`~.pipeline_service.CreateTrainingPipelineRequest`):
+ The request object. Request message for
+ ``PipelineService.CreateTrainingPipeline``.
+ parent (:class:`str`):
+ Required. The resource name of the Location to create
+ the TrainingPipeline in. Format:
+ ``projects/{project}/locations/{location}``
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ training_pipeline (:class:`~.gca_training_pipeline.TrainingPipeline`):
+ Required. The TrainingPipeline to
+ create.
+ This corresponds to the ``training_pipeline`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_training_pipeline.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with + training a Model. It always executes the training task, + and optionally may also export data from AI Platform's + Dataset which becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, training_pipeline]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = pipeline_service.CreateTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_training_pipeline(self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: + r"""Gets a TrainingPipeline. + + Args: + request (:class:`~.pipeline_service.GetTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.GetTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline resource. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.training_pipeline.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with + training a Model. It always executes the training task, + and optionally may also export data from AI Platform's + Dataset which becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
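+ # (Editor's illustration, hypothetical values) Pass either a request
+ # object or the flattened ``name`` -- never both:
+ #
+ #     await client.get_training_pipeline(name=pipeline_name)      # OK
+ #     await client.get_training_pipeline(request=req)             # OK
+ #     await client.get_training_pipeline(request=req, name=name)  # raises ValueError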
+ if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = pipeline_service.GetTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_training_pipelines(self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: + r"""Lists TrainingPipelines in a Location. + + Args: + request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + The request object. Request message for + ``PipelineService.ListTrainingPipelines``. + parent (:class:`str`): + Required. The resource name of the Location to list the + TrainingPipelines from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListTrainingPipelinesAsyncPager: + Response message for + ``PipelineService.ListTrainingPipelines`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = pipeline_service.ListTrainingPipelinesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_training_pipelines, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. 
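+ # (Editor's illustration, hypothetical values) The pager is consumed
+ # with ``async for``; further pages are fetched lazily as iteration
+ # crosses page boundaries:
+ #
+ #     pager = await client.list_training_pipelines(parent=parent)
+ #     async for pipeline in pager:
+ #         print(pipeline.name)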
+ response = pagers.ListTrainingPipelinesAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_training_pipeline(self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a TrainingPipeline. + + Args: + request (:class:`~.pipeline_service.DeleteTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.DeleteTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline resource to + be deleted. Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = pipeline_service.DeleteTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_training_pipeline(self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a TrainingPipeline. 
Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`~.pipeline_service.CancelTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.CancelTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline to cancel. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = pipeline_service.CancelTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. 
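+ # (Editor's note, illustrative) The RPC returns google.protobuf.Empty,
+ # which is discarded; a successful call only means the cancellation
+ # request was accepted. Poll ``get_training_pipeline`` to observe the
+ # final state, as the docstring above describes.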
+ await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'PipelineServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 2530414b9a..fbecd0dc70 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -16,33 +16,39 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore -from .transports.base import PipelineServiceTransport +from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import PipelineServiceGrpcTransport +from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport class PipelineServiceClientMeta(type): @@ -52,13 +58,13 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. 
""" + _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry['grpc'] = PipelineServiceGrpcTransport + _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry["grpc"] = PipelineServiceGrpcTransport - - def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. Args: @@ -80,8 +86,38 @@ def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTranspor class PipelineServiceClient(metaclass=PipelineServiceClientMeta): """A service for creating and managing AI Platform's pipelines.""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -98,35 +134,116 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file + @property + def transport(self) -> PipelineServiceTransport: + """Return the transport used by the client instance. + + Returns: + PipelineServiceTransport: The transport used by the client instance. 
+ """ + return self._transport + @staticmethod - def model_path(project: str, location: str, model: str,) -> str: + def endpoint_path(project: str,location: str,endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str,str]: + """Parse a endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_path(project: str,location: str,model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) @staticmethod - def training_pipeline_path( - project: str, location: str, training_pipeline: str, - ) -> str: + def parse_model_path(path: str) -> Dict[str,str]: + """Parse a model path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + + @staticmethod + def parse_training_pipeline_path(path: str) -> Dict[str,str]: + """Parse a training_pipeline path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return 
"projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = None, - client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, - ) -> None: + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -138,38 +255,107 @@ def __init__( transport (Union[str, ~.PipelineServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. 
+ use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))
+
+ ssl_credentials = None
+ is_mtls = False
+ if use_client_cert:
+ if client_options.client_cert_source:
+ import grpc # type: ignore
+
+ cert, key = client_options.client_cert_source()
+ ssl_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ is_mtls = True
+ else:
+ creds = SslCredentials()
+ is_mtls = creds.is_mtls
+ ssl_credentials = creds.ssl_credentials if is_mtls else None
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None:
+ api_endpoint = client_options.api_endpoint
+ else:
+ use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+ if use_mtls_env == "never":
+ api_endpoint = self.DEFAULT_ENDPOINT
+ elif use_mtls_env == "always":
+ api_endpoint = self.DEFAULT_MTLS_ENDPOINT
+ elif use_mtls_env == "auto":
+ api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
+ else:
+ raise MutualTLSChannelError(
+ "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always"
+ )
# Save or instantiate the transport.
# Ordinarily, we provide the transport, but allowing a custom transport
# instance provides an extensibility point for unusual situations.
if isinstance(transport, PipelineServiceTransport):
- if credentials:
+ # transport is a PipelineServiceTransport instance.
+ if credentials or client_options.credentials_file:
+ raise ValueError('When providing a transport instance, '
+ 'provide its credentials directly.')
+ if client_options.scopes:
raise ValueError(
"When providing a transport instance, "
- "provide its credentials directly."
+ "provide its scopes directly."
)
self._transport = transport
else:
Transport = type(self).get_transport_class(transport)
self._transport = Transport(
credentials=credentials,
- host=client_options.api_endpoint or "aiplatform.googleapis.com",
+ credentials_file=client_options.credentials_file,
+ host=api_endpoint,
+ scopes=client_options.scopes,
+ ssl_channel_credentials=ssl_credentials,
+ quota_project_id=client_options.quota_project_id,
+ client_info=client_info,
)
- def create_training_pipeline(
- self,
- request: pipeline_service.CreateTrainingPipelineRequest = None,
- *,
- parent: str = None,
- training_pipeline: gca_training_pipeline.TrainingPipeline = None,
- retry: retries.Retry = gapic_v1.method.DEFAULT,
- timeout: float = None,
- metadata: Sequence[Tuple[str, str]] = (),
- ) -> gca_training_pipeline.TrainingPipeline:
+ def create_training_pipeline(self,
+ request: pipeline_service.CreateTrainingPipelineRequest = None,
+ *,
+ parent: str = None,
+ training_pipeline: gca_training_pipeline.TrainingPipeline = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gca_training_pipeline.TrainingPipeline:
r"""Creates a TrainingPipeline. An attempt to run the
newly created TrainingPipeline is made right away.
@@ -210,45 +396,57 @@ def create_training_pipeline(
# Create or coerce a protobuf request object.
# Sanity check: If we got a request object, we should *not* have
# gotten any keyword arguments that map to the request.
- if request is not None and any([parent, training_pipeline]):
- raise ValueError(
- "If the `request` argument is set, then none of "
- "the individual field arguments should be set."
- ) - - request = pipeline_service.CreateTrainingPipelineRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if training_pipeline is not None: - request.training_pipeline = training_pipeline + has_flattened_params = any([parent, training_pipeline]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CreateTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): + request = pipeline_service.CreateTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_training_pipeline, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def get_training_pipeline( - self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline(self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -283,49 +481,55 @@ def get_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = pipeline_service.GetTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.GetTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): + request = pipeline_service.GetTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_training_pipeline, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_training_pipeline] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response - def list_training_pipelines( - self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines(self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -358,55 +562,64 @@ def list_training_pipelines( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = pipeline_service.ListTrainingPipelinesRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.ListTrainingPipelinesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): + request = pipeline_service.ListTrainingPipelinesRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_training_pipelines, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_training_pipelines] # Certain fields should be provided within the metadata header; # add these here. 
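+        # The resource name is sent as the x-goog-request-params header so
+        # the backend can route the request appropriately.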
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline( - self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_training_pipeline(self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -452,30 +665,43 @@ def delete_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = pipeline_service.DeleteTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.DeleteTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): + request = pipeline_service.DeleteTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_training_pipeline, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -488,15 +714,14 @@ def delete_training_pipeline( # Done; return the response. 
return response - def cancel_training_pipeline( - self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline(self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -533,42 +758,60 @@ def cancel_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = pipeline_service.CancelTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CancelTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): + request = pipeline_service.CancelTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_training_pipeline, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. 
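+        # This RPC returns google.protobuf.Empty, so there is no response to
+        # hand back to the caller.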
rpc( - request, retry=retry, timeout=timeout, metadata=metadata, + request, + retry=retry, + timeout=timeout, + metadata=metadata, ) + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PipelineServiceClient",) +__all__ = ( + 'PipelineServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 0db54250ef..beee148035 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -38,16 +38,12 @@ class ListTrainingPipelinesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, - ], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - ): + def __init__(self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -57,10 +53,13 @@ def __init__( The initial request object. response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -70,7 +69,7 @@ def pages(self) -> Iterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: @@ -78,4 +77,67 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: yield from page.training_pipelines def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListTrainingPipelinesAsyncPager: + """A pager for iterating through ``list_training_pipelines`` requests. 
+ + This class thinly wraps an initial + :class:`~.pipeline_service.ListTrainingPipelinesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``training_pipelines`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTrainingPipelines`` requests and continue to iterate + through the ``training_pipelines`` field on the + corresponding responses. + + All the usual :class:`~.pipeline_service.ListTrainingPipelinesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[pipeline_service.ListTrainingPipelinesResponse]], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + The initial request object. + response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = pipeline_service.ListTrainingPipelinesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[training_pipeline.TrainingPipeline]: + async def async_generator(): + async for page in self.pages: + for response in page.training_pipelines: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index 615b2c1025..3caa4c7906 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import PipelineServiceTransport from .grpc import PipelineServiceGrpcTransport +from .grpc_asyncio import PipelineServiceGrpcAsyncIOTransport # Compile a registry of transports. 
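+# The registry keys ('grpc', 'grpc_asyncio') are the values accepted by the
+# ``transport`` argument of the client constructor.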
_transport_registry = OrderedDict()  # type: Dict[str, Type[PipelineServiceTransport]]
-_transport_registry["grpc"] = PipelineServiceGrpcTransport
+_transport_registry['grpc'] = PipelineServiceGrpcTransport
+_transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport
 
 
 __all__ = (
-    "PipelineServiceTransport",
-    "PipelineServiceGrpcTransport",
+    'PipelineServiceTransport',
+    'PipelineServiceGrpcTransport',
+    'PipelineServiceGrpcAsyncIOTransport',
 )
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py
index 5696ede4d7..0a74b8e8b6 100644
--- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py
@@ -17,31 +17,48 @@
 
 import abc
 import typing
+import pkg_resources
 
-from google import auth
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
 from google.api_core import operations_v1  # type: ignore
 from google.auth import credentials  # type: ignore
 
 from google.cloud.aiplatform_v1beta1.types import pipeline_service
 from google.cloud.aiplatform_v1beta1.types import training_pipeline
-from google.cloud.aiplatform_v1beta1.types import (
-    training_pipeline as gca_training_pipeline,
-)
+from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline
 from google.longrunning import operations_pb2 as operations  # type: ignore
 from google.protobuf import empty_pb2 as empty  # type: ignore
 
 
-class PipelineServiceTransport(metaclass=abc.ABCMeta):
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            'google-cloud-aiplatform',
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+class PipelineServiceTransport(abc.ABC):
     """Abstract transport class for PipelineService."""
 
-    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+    AUTH_SCOPES = (
+        'https://www.googleapis.com/auth/cloud-platform',
+    )
 
     def __init__(
-        self,
-        *,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-    ) -> None:
+            self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: typing.Optional[str] = None,
+            scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+            quota_project_id: typing.Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            **kwargs,
+            ) -> None:
         """Instantiate the transport.
 
         Args:
@@ -51,65 +68,126 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host # If no credentials are provided, then determine the appropriate # defaults. - if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_training_pipeline: gapic_v1.method.wrap_method( + self.create_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.get_training_pipeline: gapic_v1.method.wrap_method( + self.get_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.list_training_pipelines: gapic_v1.method.wrap_method( + self.list_training_pipelines, + default_timeout=None, + client_info=client_info, + ), + self.delete_training_pipeline: gapic_v1.method.wrap_method( + self.delete_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.cancel_training_pipeline: gapic_v1.method.wrap_method( + self.cancel_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property - def create_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline, - ]: - raise NotImplementedError + def create_training_pipeline(self) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline] + ]]: + raise NotImplementedError() @property - def get_training_pipeline( - self, - ) -> typing.Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline, - ]: - raise NotImplementedError + def get_training_pipeline(self) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline] + ]]: + raise NotImplementedError() @property - def list_training_pipelines( - self, - ) -> typing.Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, - ]: - raise NotImplementedError + def list_training_pipelines(self) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ]]: + raise NotImplementedError() @property - def delete_training_pipeline( - self, - ) -> typing.Callable[ - 
[pipeline_service.DeleteTrainingPipelineRequest], operations.Operation - ]: - raise NotImplementedError + def delete_training_pipeline(self) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def cancel_training_pipeline( - self, - ) -> typing.Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: - raise NotImplementedError + def cancel_training_pipeline(self) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[ + empty.Empty, + typing.Awaitable[empty.Empty] + ]]: + raise NotImplementedError() -__all__ = ("PipelineServiceTransport",) +__all__ = ( + 'PipelineServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 7ce95caab7..096505204b 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -15,23 +15,25 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import PipelineServiceTransport +from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO class PipelineServiceGrpcTransport(PipelineServiceTransport): @@ -46,14 +48,20 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - channel: grpc.Channel = None - ) -> None: + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -64,29 +72,107 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. 
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+          google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+              creation failed for any reason.
+          google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+              and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False
 
-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]
 
-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
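+        # The base constructor checks that credentials and credentials_file
+        # are not both set and precomputes the wrapped RPC methods.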
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
 
     @classmethod
-    def create_channel(
-        cls,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        **kwargs
-    ) -> grpc.Channel:
+    def create_channel(cls,
+                       host: str = 'aiplatform.googleapis.com',
+                       credentials: credentials.Credentials = None,
+                       credentials_file: str = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
             address (Optional[str]): The host for the channel to use.
@@ -95,30 +181,37 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
+        scopes = scopes or cls.AUTH_SCOPES
         return grpc_helpers.create_channel(
-            host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
         )
 
     @property
     def grpc_channel(self) -> grpc.Channel:
-        """Create the channel designed to connect to this service.
-
-        This property caches on the instance; repeated calls return
-        the same channel.
+        """Return the channel designed to connect to this service.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
-        if not hasattr(self, "_grpc_channel"):
-            self._grpc_channel = self.create_channel(
-                self._host, credentials=self._credentials,
-            )
-
-        # Return the channel from cache.
         return self._grpc_channel
 
     @property
@@ -129,21 +222,18 @@ def operations_client(self) -> operations_v1.OperationsClient:
         client.
         """
         # Sanity check: Only create a new client if we do not already have one.
-        if "operations_client" not in self.__dict__:
-            self.__dict__["operations_client"] = operations_v1.OperationsClient(
+        if 'operations_client' not in self.__dict__:
+            self.__dict__['operations_client'] = operations_v1.OperationsClient(
                 self.grpc_channel
             )
 
         # Return the client from cache.
-        return self.__dict__["operations_client"]
+        return self.__dict__['operations_client']
 
     @property
-    def create_training_pipeline(
-        self,
-    ) -> Callable[
-        [pipeline_service.CreateTrainingPipelineRequest],
-        gca_training_pipeline.TrainingPipeline,
-    ]:
+    def create_training_pipeline(self) -> Callable[
+            [pipeline_service.CreateTrainingPipelineRequest],
+            gca_training_pipeline.TrainingPipeline]:
         r"""Return a callable for the create training pipeline method over gRPC.
 
         Creates a TrainingPipeline. A created
@@ -159,21 +249,18 @@ def create_training_pipeline(
        # the request.
# gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_training_pipeline" not in self._stubs: - self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", + if 'create_training_pipeline' not in self._stubs: + self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["create_training_pipeline"] + return self._stubs['create_training_pipeline'] @property - def get_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline, - ]: + def get_training_pipeline(self) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -188,21 +275,18 @@ def get_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_training_pipeline" not in self._stubs: - self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", + if 'get_training_pipeline' not in self._stubs: + self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs["get_training_pipeline"] + return self._stubs['get_training_pipeline'] @property - def list_training_pipelines( - self, - ) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, - ]: + def list_training_pipelines(self) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -217,20 +301,18 @@ def list_training_pipelines( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_training_pipelines" not in self._stubs: - self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", + if 'list_training_pipelines' not in self._stubs: + self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs["list_training_pipelines"] + return self._stubs['list_training_pipelines'] @property - def delete_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation - ]: + def delete_training_pipeline(self) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + operations.Operation]: r"""Return a callable for the delete training pipeline method over gRPC. 
Deletes a TrainingPipeline. @@ -245,18 +327,18 @@ def delete_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_training_pipeline" not in self._stubs: - self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", + if 'delete_training_pipeline' not in self._stubs: + self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_training_pipeline"] + return self._stubs['delete_training_pipeline'] @property - def cancel_training_pipeline( - self, - ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: + def cancel_training_pipeline(self) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + empty.Empty]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -283,13 +365,15 @@ def cancel_training_pipeline( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "cancel_training_pipeline" not in self._stubs: - self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", + if 'cancel_training_pipeline' not in self._stubs: + self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs["cancel_training_pipeline"] + return self._stubs['cancel_training_pipeline'] -__all__ = ("PipelineServiceGrpcTransport",) +__all__ = ( + 'PipelineServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..ce9bc0c191 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -0,0 +1,384 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import pipeline_service
+from google.cloud.aiplatform_v1beta1.types import training_pipeline
+from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import PipelineServiceGrpcTransport
+
+
+class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport):
+    """gRPC AsyncIO backend transport for PipelineService.
+
+    A service for creating and managing AI Platform's pipelines.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(cls,
+                       host: str = 'aiplatform.googleapis.com',
+                       credentials: credentials.Credentials = None,
+                       credentials_file: Optional[str] = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            address (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
+        )
+
+    def __init__(self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: aio.Channel = None,
+            api_mtls_endpoint: str = None,
+            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+            ssl_channel_credentials: grpc.ChannelCredentials = None,
+            quota_project_id=None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+ if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + + @property + def create_training_pipeline(self) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline]]: + r"""Return a callable for the create training pipeline method over gRPC. + + Creates a TrainingPipeline. A created + TrainingPipeline right away will be attempted to be run. + + Returns: + Callable[[~.CreateTrainingPipelineRequest], + Awaitable[~.TrainingPipeline]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_training_pipeline' not in self._stubs: + self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', + request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, + response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, + ) + return self._stubs['create_training_pipeline'] + + @property + def get_training_pipeline(self) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline]]: + r"""Return a callable for the get training pipeline method over gRPC. + + Gets a TrainingPipeline. 
+ + Returns: + Callable[[~.GetTrainingPipelineRequest], + Awaitable[~.TrainingPipeline]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_training_pipeline' not in self._stubs: + self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', + request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, + response_deserializer=training_pipeline.TrainingPipeline.deserialize, + ) + return self._stubs['get_training_pipeline'] + + @property + def list_training_pipelines(self) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse]]: + r"""Return a callable for the list training pipelines method over gRPC. + + Lists TrainingPipelines in a Location. + + Returns: + Callable[[~.ListTrainingPipelinesRequest], + Awaitable[~.ListTrainingPipelinesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_training_pipelines' not in self._stubs: + self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', + request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, + response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, + ) + return self._stubs['list_training_pipelines'] + + @property + def delete_training_pipeline(self) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete training pipeline method over gRPC. + + Deletes a TrainingPipeline. + + Returns: + Callable[[~.DeleteTrainingPipelineRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_training_pipeline' not in self._stubs: + self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', + request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_training_pipeline'] + + @property + def cancel_training_pipeline(self) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + Awaitable[empty.Empty]]: + r"""Return a callable for the cancel training pipeline method over gRPC. + + Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. 
On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelTrainingPipelineRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'cancel_training_pipeline' not in self._stubs: + self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', + request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs['cancel_training_pipeline'] + + +__all__ = ( + 'PipelineServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py index 9e3af89360..d4047c335d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py @@ -16,5 +16,9 @@ # from .client import PredictionServiceClient +from .async_client import PredictionServiceAsyncClient -__all__ = ("PredictionServiceClient",) +__all__ = ( + 'PredictionServiceClient', + 'PredictionServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py new file mode 100644 index 0000000000..283bb73f3e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.protobuf import struct_pb2 as struct # type: ignore + +from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport +from .client import PredictionServiceClient + + +class PredictionServiceAsyncClient: + """A service for online predictions and explanations.""" + + _client: PredictionServiceClient + + DEFAULT_ENDPOINT = PredictionServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) + + common_billing_account_path = staticmethod(PredictionServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PredictionServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(PredictionServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(PredictionServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PredictionServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(PredictionServiceClient.common_project_path) + parse_common_project_path = staticmethod(PredictionServiceClient.parse_common_project_path) + + common_location_path = staticmethod(PredictionServiceClient.common_location_path) + parse_common_location_path = staticmethod(PredictionServiceClient.parse_common_location_path) + + from_service_account_file = PredictionServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PredictionServiceTransport: + """Return the transport used by the client instance. + + Returns: + PredictionServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the prediction service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.PredictionServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. 
+ client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = PredictionServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def predict(self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: + r"""Perform an online prediction. + + Args: + request (:class:`~.prediction_service.PredictRequest`): + The request object. Request message for + ``PredictionService.Predict``. + endpoint (:class:`str`): + Required. The name of the Endpoint requested to serve + the prediction. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (:class:`Sequence[~.struct.Value]`): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + parameters (:class:`~.struct.Value`): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.prediction_service.PredictResponse: + Response message for + ``PredictionService.Predict``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([endpoint, instances, parameters]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = prediction_service.PredictRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if instances is not None: + request.instances = instances + if parameters is not None: + request.parameters = parameters + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.predict, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def explain(self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: + r"""Perform an online explanation. + + If + ``deployed_model_id`` + is specified, the corresponding DeployModel must have + ``explanation_spec`` + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have + ``explanation_spec`` + populated. Only deployed AutoML tabular Models have + explanation_spec. + + Args: + request (:class:`~.prediction_service.ExplainRequest`): + The request object. Request message for + ``PredictionService.Explain``. + endpoint (:class:`str`): + Required. The name of the Endpoint requested to serve + the explanation. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (:class:`Sequence[~.struct.Value]`): + Required. The instances that are the input to the + explanation call. A DeployedModel may have an upper + limit on the number of instances it supports per + request, and when it is exceeded the explanation call + errors in case of AutoML Models, or, in case of customer + created Models, the behaviour is as documented by that + Model. The schema of any single instance may be + specified via Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ parameters (:class:`~.struct.Value`): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model_id (:class:`str`): + If specified, this ExplainRequest will be served by the + chosen DeployedModel, overriding + ``Endpoint.traffic_split``. + This corresponds to the ``deployed_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.prediction_service.ExplainResponse: + Response message for + ``PredictionService.Explain``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([endpoint, instances, parameters, deployed_model_id]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = prediction_service.ExplainRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if instances is not None: + request.instances = instances + if parameters is not None: + request.parameters = parameters + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.explain, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
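+        # A minimal calling sketch (illustrative resource names, assuming
+        # application default credentials and a running event loop; not part
+        # of the generated surface):
+        #
+        #     client = PredictionServiceAsyncClient()
+        #     response = await client.explain(
+        #         endpoint="projects/my-project/locations/us-central1/endpoints/123",
+        #         instances=[struct.Value(string_value="example")],
+        #     )
+        #     print(response.explanations)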
+ return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'PredictionServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index dbdf226471..2627b20ae3 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -16,22 +16,29 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore -from .transports.base import PredictionServiceTransport +from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import PredictionServiceGrpcTransport +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport class PredictionServiceClientMeta(type): @@ -41,15 +48,13 @@ class PredictionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry['grpc'] = PredictionServiceGrpcTransport + _transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[PredictionServiceTransport]] - _transport_registry["grpc"] = PredictionServiceGrpcTransport - - def get_transport_class( - cls, label: str = None, - ) -> Type[PredictionServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[PredictionServiceTransport]: """Return an appropriate transport class. Args: @@ -71,8 +76,38 @@ def get_transport_class( class PredictionServiceClient(metaclass=PredictionServiceClientMeta): """A service for online predictions and explanations.""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. 
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -89,19 +124,94 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
 
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(filename)
-        kwargs["credentials"] = credentials
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs['credentials'] = credentials
         return cls(*args, **kwargs)
 
     from_service_account_json = from_service_account_file
 
-    def __init__(
-        self,
-        *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, PredictionServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
-    ) -> None:
+    @property
+    def transport(self) -> PredictionServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            PredictionServiceTransport: The transport used by the client instance.
+ """ + return self._transport + + @staticmethod + def endpoint_path(project: str,location: str,endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str,str]: + """Parse a endpoint path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__(self, *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -113,39 +223,108 @@ def __init__( transport (Union[str, ~.PredictionServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PredictionServiceTransport): - if credentials: + # transport is a PredictionServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: raise ValueError( "When providing a transport instance, " - "provide its credentials directly." + "provide its scopes directly." 
) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) - def predict( - self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + def predict(self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -200,55 +379,72 @@ def predict( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, instances, parameters]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([endpoint, instances, parameters]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.PredictRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) - request = prediction_service.PredictRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. - # If we have keyword arguments corresponding to fields on the - # request, apply these. + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances.extend(instances) - if parameters is not None: - request.parameters = parameters + if instances: + request.instances.extend(instances) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.predict, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. 
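+        # A minimal calling sketch (illustrative resource names, assuming
+        # application default credentials; not part of the generated surface):
+        #
+        #     client = PredictionServiceClient()
+        #     response = client.predict(
+        #         endpoint=client.endpoint_path("my-project", "us-central1", "123"),
+        #         instances=[struct.Value(string_value="example")],
+        #     )
+        #     print(response.predictions)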
return response - def explain( - self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + def explain(self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. - If [ExplainRequest.deployed_model_id] is specified, the - corresponding DeployModel must have + If + ``deployed_model_id`` + is specified, the corresponding DeployModel must have ``explanation_spec`` - populated. If [ExplainRequest.deployed_model_id] is not - specified, all DeployedModels must have + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have ``explanation_spec`` populated. Only deployed AutoML tabular Models have explanation_spec. @@ -312,49 +508,70 @@ def explain( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any( - [endpoint, instances, parameters, deployed_model_id] - ): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = prediction_service.ExplainRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances.extend(instances) - if parameters is not None: - request.parameters = parameters - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id + has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.ExplainRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, prediction_service.ExplainRequest): + request = prediction_service.ExplainRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + if instances: + request.instances.extend(instances) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.explain, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.explain] + + # Certain fields should be provided within the metadata header; + # add these here. 
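+        # ``to_grpc_metadata`` below emits the ``x-goog-request-params``
+        # header, which lets the backend route the call to the named
+        # ``endpoint`` resource.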
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('endpoint', request.endpoint), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. return response + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("PredictionServiceClient",) +__all__ = ( + 'PredictionServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py index 33eefca757..e130201fdf 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import PredictionServiceTransport from .grpc import PredictionServiceGrpcTransport +from .grpc_asyncio import PredictionServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] -_transport_registry["grpc"] = PredictionServiceGrpcTransport +_transport_registry['grpc'] = PredictionServiceGrpcTransport +_transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport __all__ = ( - "PredictionServiceTransport", - "PredictionServiceGrpcTransport", + 'PredictionServiceTransport', + 'PredictionServiceGrpcTransport', + 'PredictionServiceGrpcAsyncIOTransport', ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 58d508474a..86e2292130 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -17,24 +17,43 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service -class PredictionServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + AUTH_SCOPES = ( + 'https://www.googleapis.com/auth/cloud-platform', + ) def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - ) -> None: + self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = 
None,
+            scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+            quota_project_id: typing.Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            **kwargs,
+            ) -> None:
         """Instantiate the transport.
 
         Args:
@@ -44,35 +63,79 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
-        if ":" not in host:
-            host += ":443"
+        if ':' not in host:
+            host += ':443'
         self._host = host
 
         # If no credentials are provided, then determine the appropriate
         # defaults.
-        if credentials is None:
-            credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file,
+                scopes=scopes,
+                quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)
 
         # Save the credentials.
         self._credentials = credentials
 
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
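+        # The wrapping below happens once per transport instance, so every
+        # call through ``_wrapped_methods`` picks up the retry/timeout
+        # defaults and the ``client_info`` user-agent without per-call setup.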
+ self._wrapped_methods = { + self.predict: gapic_v1.method.wrap_method( + self.predict, + default_timeout=None, + client_info=client_info, + ), + self.explain: gapic_v1.method.wrap_method( + self.explain, + default_timeout=None, + client_info=client_info, + ), + + } + @property - def predict( - self, - ) -> typing.Callable[ - [prediction_service.PredictRequest], prediction_service.PredictResponse - ]: - raise NotImplementedError + def predict(self) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse] + ]]: + raise NotImplementedError() @property - def explain( - self, - ) -> typing.Callable[ - [prediction_service.ExplainRequest], prediction_service.ExplainResponse - ]: - raise NotImplementedError + def explain(self) -> typing.Callable[ + [prediction_service.ExplainRequest], + typing.Union[ + prediction_service.ExplainResponse, + typing.Awaitable[prediction_service.ExplainResponse] + ]]: + raise NotImplementedError() -__all__ = ("PredictionServiceTransport",) +__all__ = ( + 'PredictionServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 55824a233c..520120cfa3 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -15,16 +15,20 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service -from .base import PredictionServiceTransport +from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO class PredictionServiceGrpcTransport(PredictionServiceTransport): @@ -39,14 +43,20 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - channel: grpc.Channel = None - ) -> None: + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -57,29 +67,107 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. 
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+          google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+              creation failed for any reason.
+          google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+              and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False
 
-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]
 
-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
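+        # Note: the base constructor also calls ``_prep_wrapped_messages``,
+        # so the method wrappers exist as soon as the transport is built.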
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
 
     @classmethod
-    def create_channel(
-        cls,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        **kwargs
-    ) -> grpc.Channel:
+    def create_channel(cls,
+                       host: str = 'aiplatform.googleapis.com',
+                       credentials: credentials.Credentials = None,
+                       credentials_file: str = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
             address (Optional[str]): The host for the channel to use.
@@ -88,38 +176,43 @@ def create_channel(
             credentials (Optional[~.Credentials]): The
                 credentials identify this application to the service. If
                 none are specified, the client will attempt to ascertain
                 the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
+        scopes = scopes or cls.AUTH_SCOPES
         return grpc_helpers.create_channel(
-            host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
        )
 
     @property
     def grpc_channel(self) -> grpc.Channel:
-        """Create the channel designed to connect to this service.
-
-        This property caches on the instance; repeated calls return
-        the same channel.
+        """Return the channel designed to connect to this service.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
-        if not hasattr(self, "_grpc_channel"):
-            self._grpc_channel = self.create_channel(
-                self._host, credentials=self._credentials,
-            )
-
-        # Return the channel from cache.
         return self._grpc_channel
 
     @property
-    def predict(
-        self,
-    ) -> Callable[
-        [prediction_service.PredictRequest], prediction_service.PredictResponse
-    ]:
+    def predict(self) -> Callable[
+            [prediction_service.PredictRequest],
+            prediction_service.PredictResponse]:
         r"""Return a callable for the predict method over gRPC.
 
         Perform an online prediction.
 
@@ -134,29 +227,29 @@ def predict(
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
- if "predict" not in self._stubs: - self._stubs["predict"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", + if 'predict' not in self._stubs: + self._stubs['predict'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs["predict"] + return self._stubs['predict'] @property - def explain( - self, - ) -> Callable[ - [prediction_service.ExplainRequest], prediction_service.ExplainResponse - ]: + def explain(self) -> Callable[ + [prediction_service.ExplainRequest], + prediction_service.ExplainResponse]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. - If [ExplainRequest.deployed_model_id] is specified, the - corresponding DeployModel must have + If + ``deployed_model_id`` + is specified, the corresponding DeployModel must have ``explanation_spec`` - populated. If [ExplainRequest.deployed_model_id] is not - specified, all DeployedModels must have + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have ``explanation_spec`` populated. Only deployed AutoML tabular Models have explanation_spec. @@ -171,13 +264,15 @@ def explain( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "explain" not in self._stubs: - self._stubs["explain"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", + if 'explain' not in self._stubs: + self._stubs['explain'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs["explain"] + return self._stubs['explain'] -__all__ = ("PredictionServiceGrpcTransport",) +__all__ = ( + 'PredictionServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..1a1d48b450 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -0,0 +1,283 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import prediction_service
+
+from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import PredictionServiceGrpcTransport
+
+
+class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport):
+    """gRPC AsyncIO backend transport for PredictionService.
+
+    A service for online predictions and explanations.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(cls,
+                       host: str = 'aiplatform.googleapis.com',
+                       credentials: credentials.Credentials = None,
+                       credentials_file: Optional[str] = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
+        )
+
+    def __init__(self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: aio.Channel = None,
+            api_mtls_endpoint: str = None,
+            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+            ssl_channel_credentials: grpc.ChannelCredentials = None,
+            quota_project_id=None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
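+            # ``grpc_helpers_async.create_channel`` returns a secure
+            # ``aio.Channel`` with the resolved credentials attached, so every
+            # call made on it is authenticated.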
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+        self._stubs = {}
+
+    @property
+    def grpc_channel(self) -> aio.Channel:
+        """Create the channel designed to connect to this service.
+
+        This property caches on the instance; repeated calls return
+        the same channel.
+        """
+        # Return the channel from cache.
+        return self._grpc_channel
+
+    @property
+    def predict(self) -> Callable[
+            [prediction_service.PredictRequest],
+            Awaitable[prediction_service.PredictResponse]]:
+        r"""Return a callable for the predict method over gRPC.
+
+        Perform an online prediction.
+
+        Returns:
+            Callable[[~.PredictRequest],
+                    Awaitable[~.PredictResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'predict' not in self._stubs:
+            self._stubs['predict'] = self.grpc_channel.unary_unary(
+                '/google.cloud.aiplatform.v1beta1.PredictionService/Predict',
+                request_serializer=prediction_service.PredictRequest.serialize,
+                response_deserializer=prediction_service.PredictResponse.deserialize,
+            )
+        return self._stubs['predict']
+
+    @property
+    def explain(self) -> Callable[
+            [prediction_service.ExplainRequest],
+            Awaitable[prediction_service.ExplainResponse]]:
+        r"""Return a callable for the explain method over gRPC.
+
+        Perform an online explanation.
+
+        If ``deployed_model_id`` is specified, the corresponding
+        DeployedModel must have ``explanation_spec`` populated. If
+        ``deployed_model_id`` is not specified, all DeployedModels must
+        have ``explanation_spec`` populated. Only deployed AutoML
+        tabular Models have explanation_spec.
+
+        Returns:
+            Callable[[~.ExplainRequest],
+                    Awaitable[~.ExplainResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
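+        # Illustrative usage sketch (editorial, not generated code; the
+        # endpoint name is hypothetical):
+        #
+        #     explain = transport.explain
+        #     response = await explain(prediction_service.ExplainRequest(
+        #         endpoint='projects/my-project/locations/us-central1/endpoints/123',
+        #     ))
+        #
+        # As with ``predict`` above, the stub is created once and cached in
+        # ``self._stubs`` on subsequent accesses.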
+        if 'explain' not in self._stubs:
+            self._stubs['explain'] = self.grpc_channel.unary_unary(
+                '/google.cloud.aiplatform.v1beta1.PredictionService/Explain',
+                request_serializer=prediction_service.ExplainRequest.serialize,
+                response_deserializer=prediction_service.ExplainResponse.deserialize,
+            )
+        return self._stubs['explain']
+
+
+__all__ = (
+    'PredictionServiceGrpcAsyncIOTransport',
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py
index 8f429cd5eb..e4247d7758 100644
--- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py
@@ -16,5 +16,9 @@
 #
 from .client import SpecialistPoolServiceClient
+from .async_client import SpecialistPoolServiceAsyncClient
 
-__all__ = ("SpecialistPoolServiceClient",)
+__all__ = (
+    'SpecialistPoolServiceClient',
+    'SpecialistPoolServiceAsyncClient',
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py
new file mode 100644
index 0000000000..c4ea8855c1
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py
@@ -0,0 +1,636 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import OrderedDict
+import functools
+import re
+from typing import Dict, Sequence, Tuple, Type, Union
+import pkg_resources
+
+import google.api_core.client_options as ClientOptions  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.api_core import operation as ga_operation  # type: ignore
+from google.api_core import operation_async  # type: ignore
+from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers
+from google.cloud.aiplatform_v1beta1.types import operation as gca_operation
+from google.cloud.aiplatform_v1beta1.types import specialist_pool
+from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool
+from google.cloud.aiplatform_v1beta1.types import specialist_pool_service
+from google.protobuf import empty_pb2 as empty  # type: ignore
+from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
+
+from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport
+from .client import SpecialistPoolServiceClient
+
+
+class SpecialistPoolServiceAsyncClient:
+    """A service for creating and managing Customers' SpecialistPools.
+ When customers start Data Labeling jobs, they can reuse/create + Specialist Pools to bring their own Specialists to label the + data. Customers can add/remove Managers for the Specialist Pool + on Cloud console, then Managers will get email notifications to + manage Specialists and tasks on CrowdCompute console. + """ + + _client: SpecialistPoolServiceClient + + DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT + + specialist_pool_path = staticmethod(SpecialistPoolServiceClient.specialist_pool_path) + parse_specialist_pool_path = staticmethod(SpecialistPoolServiceClient.parse_specialist_pool_path) + + common_billing_account_path = staticmethod(SpecialistPoolServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(SpecialistPoolServiceClient.parse_common_billing_account_path) + + common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(SpecialistPoolServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(SpecialistPoolServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(SpecialistPoolServiceClient.parse_common_organization_path) + + common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) + parse_common_project_path = staticmethod(SpecialistPoolServiceClient.parse_common_project_path) + + common_location_path = staticmethod(SpecialistPoolServiceClient.common_location_path) + parse_common_location_path = staticmethod(SpecialistPoolServiceClient.parse_common_location_path) + + from_service_account_file = SpecialistPoolServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> SpecialistPoolServiceTransport: + """Return the transport used by the client instance. + + Returns: + SpecialistPoolServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(SpecialistPoolServiceClient).get_transport_class, type(SpecialistPoolServiceClient)) + + def __init__(self, *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = 'grpc_asyncio', + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the specialist pool service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.SpecialistPoolServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). 
However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = SpecialistPoolServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def create_specialist_pool(self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a SpecialistPool. + + Args: + request (:class:`~.specialist_pool_service.CreateSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.CreateSpecialistPool``. + parent (:class:`str`): + Required. The parent Project name for the new + SpecialistPool. The form is + ``projects/{project}/locations/{location}``. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + Required. The SpecialistPool to + create. + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_specialist_pool.SpecialistPool`: + SpecialistPool represents customers' own workforce to + work on their data labeling jobs. It includes a group of + specialist managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent, specialist_pool]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = specialist_pool_service.CreateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
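+        # (Editorial note: ``wrap_method`` returns a coroutine function that
+        # applies the default retry/timeout policy; callers may still
+        # override both per call, e.g. ``await rpc(request, timeout=30.0)``.)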
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.CreateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_specialist_pool(self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: + r"""Gets a SpecialistPool. + + Args: + request (:class:`~.specialist_pool_service.GetSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.GetSpecialistPool``. + name (:class:`str`): + Required. The name of the SpecialistPool resource. The + form is + + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.specialist_pool.SpecialistPool: + SpecialistPool represents customers' + own workforce to work on their data + labeling jobs. It includes a group of + specialist managers who are responsible + for managing the labelers in this pool + as well as customers' data labeling jobs + associated with this pool. + Customers create specialist pool as well + as start data labeling jobs on Cloud, + managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = specialist_pool_service.GetSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
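+        # Illustrative call (editorial; the resource name is hypothetical):
+        #
+        #     pool = await client.get_specialist_pool(
+        #         name='projects/my-project/locations/us-central1/specialistPools/456')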
+ return response + + async def list_specialist_pools(self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: + r"""Lists SpecialistPools in a Location. + + Args: + request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + The request object. Request message for + ``SpecialistPoolService.ListSpecialistPools``. + parent (:class:`str`): + Required. The name of the SpecialistPool's parent + resource. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListSpecialistPoolsAsyncPager: + Response message for + ``SpecialistPoolService.ListSpecialistPools``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([parent]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = specialist_pool_service.ListSpecialistPoolsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_specialist_pools, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListSpecialistPoolsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_specialist_pool(self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a SpecialistPool as well as all Specialists + in the pool. + + Args: + request (:class:`~.specialist_pool_service.DeleteSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.DeleteSpecialistPool``. + name (:class:`str`): + Required. The resource name of the SpecialistPool to + delete. 
Format: + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([name]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def update_specialist_pool(self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a SpecialistPool. + + Args: + request (:class:`~.specialist_pool_service.UpdateSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.UpdateSpecialistPool``. + specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + Required. The SpecialistPool which + replaces the resource on the server. + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to + the resource. + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
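+                (Illustrative, editorial note: a mask such as
+                ``field_mask.FieldMask(paths=['display_name'])`` restricts
+                the update to the pool's display name.)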
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_specialist_pool.SpecialistPool`: + SpecialistPool represents customers' own workforce to + work on their data labeling jobs. It includes a group of + specialist managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + if request is not None and any([specialist_pool, update_mask]): + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('specialist_pool.name', request.specialist_pool.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.UpdateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. 
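+        # Illustrative follow-up (editorial; ``pool`` and the mask are
+        # hypothetical):
+        #
+        #     op = await client.update_specialist_pool(
+        #         specialist_pool=pool,
+        #         update_mask=field_mask.FieldMask(paths=['display_name']))
+        #     updated_pool = await op.result()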
+ return response + + + + + + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + 'google-cloud-aiplatform', + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ( + 'SpecialistPoolServiceAsyncClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index ddc9c26ab9..f6938e8d1f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import specialist_pool @@ -35,8 +42,9 @@ from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import SpecialistPoolServiceTransport +from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import SpecialistPoolServiceGrpcTransport +from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport class SpecialistPoolServiceClientMeta(type): @@ -46,15 +54,13 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. 
""" + _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport + _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - _transport_registry = ( - OrderedDict() - ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport - - def get_transport_class( - cls, label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + def get_transport_class(cls, + label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. Args: @@ -82,8 +88,38 @@ class SpecialistPoolServiceClient(metaclass=SpecialistPoolServiceClientMeta): manage Specialists and tasks on CrowdCompute console. """ - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -100,26 +136,94 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file(filename) - kwargs["credentials"] = credentials + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs['credentials'] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file + @property + def transport(self) -> SpecialistPoolServiceTransport: + """Return the transport used by the client instance. + + Returns: + SpecialistPoolServiceTransport: The transport used by the client instance. 
+        """
+        return self._transport
+
     @staticmethod
-    def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str:
+    def specialist_pool_path(project: str, location: str, specialist_pool: str, ) -> str:
         """Return a fully-qualified specialist_pool string."""
-        return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(
-            project=project, location=location, specialist_pool=specialist_pool,
-        )
+        return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, )
+
+    @staticmethod
+    def parse_specialist_pool_path(path: str) -> Dict[str, str]:
+        """Parse a specialist_pool path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/specialistPools/(?P<specialist_pool>.+?)$", path)
+        return m.groupdict() if m else {}
 
-    def __init__(
-        self,
-        *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, SpecialistPoolServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
-    ) -> None:
+    @staticmethod
+    def common_billing_account_path(billing_account: str, ) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str, ) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder, )
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str, ) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization, )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str, ) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project, )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str, ) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(self, *,
+            credentials: Optional[credentials.Credentials] = None,
+            transport: Union[str, SpecialistPoolServiceTransport, None] = None,
+            client_options: Optional[client_options_lib.ClientOptions] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiate
the specialist pool service client. Args: @@ -131,38 +235,107 @@ def __init__( transport (Union[str, ~.SpecialistPoolServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, SpecialistPoolServiceTransport): - if credentials: + # transport is a SpecialistPoolServiceTransport instance. 
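+            # (Editorial note: a pre-built transport already carries its own
+            # channel, credentials and endpoint, so credential and scope
+            # options passed alongside it cannot be honored and are rejected
+            # below.)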
+ if credentials or client_options.credentials_file: + raise ValueError('When providing a transport instance, ' + 'provide its credentials directly.') + if client_options.scopes: raise ValueError( "When providing a transport instance, " - "provide its credentials directly." + "provide its scopes directly." ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) - def create_specialist_pool( - self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_specialist_pool(self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -208,32 +381,45 @@ def create_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, specialist_pool]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = specialist_pool_service.CreateSpecialistPoolRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if specialist_pool is not None: - request.specialist_pool = specialist_pool + has_flattened_params = any([parent, specialist_pool]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.CreateSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): + request = specialist_pool_service.CreateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_specialist_pool, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. 
- response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -246,15 +432,14 @@ def create_specialist_pool( # Done; return the response. return response - def get_specialist_pool( - self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool(self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -294,49 +479,55 @@ def get_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = specialist_pool_service.GetSpecialistPoolRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.GetSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): + request = specialist_pool_service.GetSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_specialist_pool, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_specialist_pool] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Done; return the response. 
return response - def list_specialist_pools( - self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools(self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -369,55 +560,64 @@ def list_specialist_pools( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = specialist_pool_service.ListSpecialistPoolsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.ListSpecialistPoolsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): + request = specialist_pool_service.ListSpecialistPoolsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_specialist_pools, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_specialist_pools] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', request.parent), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, request=request, response=response, + method=rpc, + request=request, + response=response, + metadata=metadata, ) # Done; return the response. 
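+        # Illustrative pager usage (editorial; the parent is hypothetical):
+        #
+        #     for pool in client.list_specialist_pools(
+        #             parent='projects/my-project/locations/us-central1'):
+        #         print(pool.name)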
return response - def delete_specialist_pool( - self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_specialist_pool(self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -463,30 +663,43 @@ def delete_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') - request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.DeleteSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_specialist_pool, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('name', request.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -499,16 +712,15 @@ def delete_specialist_pool( # Done; return the response. 
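+        # Illustrative call (editorial; ``pool_name`` is hypothetical).
+        # Deletion is long-running, so callers typically block on it:
+        #
+        #     client.delete_specialist_pool(name=pool_name).result()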
return response - def update_specialist_pool( - self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_specialist_pool(self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -553,32 +765,45 @@ def update_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([specialist_pool, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = specialist_pool_service.UpdateSpecialistPoolRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if specialist_pool is not None: - request.specialist_pool = specialist_pool - if update_mask is not None: - request.update_mask = update_mask + has_flattened_params = any([specialist_pool, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.UpdateSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_specialist_pool, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('specialist_pool.name', request.specialist_pool.name), + )), ) # Send the request. - response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) # Wrap the response in an operation future. 
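+        # (Editorial note: ``from_gapic`` wraps the raw longrunning Operation
+        # so that ``response.result()`` deserializes into a
+        # ``gca_specialist_pool.SpecialistPool`` once the server finishes.)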
response = ga_operation.from_gapic( @@ -592,14 +817,21 @@ def update_specialist_pool( return response + + + + + try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", + 'google-cloud-aiplatform', ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ("SpecialistPoolServiceClient",) +__all__ = ( + 'SpecialistPoolServiceClient', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 012b76479b..68093dbff5 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service @@ -38,16 +38,12 @@ class ListSpecialistPoolsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - - def __init__( - self, - method: Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, - ], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - ): + def __init__(self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): """Instantiate the pager. Args: @@ -57,10 +53,13 @@ def __init__( The initial request object. response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -70,7 +69,7 @@ def pages(self) -> Iterable[specialist_pool_service.ListSpecialistPoolsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: @@ -78,4 +77,67 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: yield from page.specialist_pools def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListSpecialistPoolsAsyncPager: + """A pager for iterating through ``list_specialist_pools`` requests. 
+ + This class thinly wraps an initial + :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``specialist_pools`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListSpecialistPools`` requests and continue to iterate + through the ``specialist_pools`` field on the + corresponding responses. + + All the usual :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + The initial request object. + response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[specialist_pool.SpecialistPool]: + async def async_generator(): + async for page in self.pages: + for response in page.specialist_pools: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index c77d2d31a3..ed5bf01517 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -20,16 +20,17 @@ from .base import SpecialistPoolServiceTransport from .grpc import SpecialistPoolServiceGrpcTransport +from .grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport # Compile a registry of transports. 
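# ---------------------------------------------------------------------------
# [Editor's example -- illustrative sketch, not part of the generated patch.]
# Consuming the ListSpecialistPoolsAsyncPager defined above: `async for`
# walks items across page boundaries, and the pager re-issues the RPC with
# the stored metadata and page_token. Ambient credentials are assumed and
# the parent resource name is hypothetical.
import asyncio

from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
    SpecialistPoolServiceAsyncClient,
)

async def list_all_pools() -> None:
    client = SpecialistPoolServiceAsyncClient()
    pager = await client.list_specialist_pools(
        parent="projects/my-project/locations/us-central1",
    )
    async for pool in pager:  # drives the `pages` generator shown above
        print(pool.name)

asyncio.run(list_all_pools())
# ---------------------------------------------------------------------------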
-_transport_registry = (
-    OrderedDict()
-)  # type: Dict[str, Type[SpecialistPoolServiceTransport]]
-_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport
+_transport_registry = OrderedDict()  # type: Dict[str, Type[SpecialistPoolServiceTransport]]
+_transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport
+_transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport
 
 
 __all__ = (
-    "SpecialistPoolServiceTransport",
-    "SpecialistPoolServiceGrpcTransport",
+    'SpecialistPoolServiceTransport',
+    'SpecialistPoolServiceGrpcTransport',
+    'SpecialistPoolServiceGrpcAsyncIOTransport',
 )
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py
index effe36767e..20c4d1cf3c 100644
--- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py
@@ -17,8 +17,12 @@
 import abc
 import typing
+import pkg_resources
 
-from google import auth
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
 from google.api_core import operations_v1  # type: ignore
 from google.auth import credentials  # type: ignore
@@ -27,17 +31,32 @@
 from google.longrunning import operations_pb2 as operations  # type: ignore
 
 
-class SpecialistPoolServiceTransport(metaclass=abc.ABCMeta):
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            'google-cloud-aiplatform',
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+class SpecialistPoolServiceTransport(abc.ABC):
     """Abstract transport class for SpecialistPoolService."""
 
-    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+    AUTH_SCOPES = (
+        'https://www.googleapis.com/auth/cloud-platform',
+    )
 
     def __init__(
-        self,
-        *,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-    ) -> None:
+            self, *,
+            host: str = 'aiplatform.googleapis.com',
+            credentials: credentials.Credentials = None,
+            credentials_file: typing.Optional[str] = None,
+            scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+            quota_project_id: typing.Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            **kwargs,
+            ) -> None:
         """Instantiate the transport.
 
         Args:
@@ -47,66 +66,126 @@ def __init__(
             credentials identify the application to the service; if
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
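# ---------------------------------------------------------------------------
# [Editor's example -- illustrative sketch, not part of the generated patch.]
# The registry compiled in transports/__init__.py above maps a short label
# to a transport class so a client can be built with transport='grpc' or
# 'grpc_asyncio'. A distilled version of that lookup; the helper function
# itself is hypothetical, not generated code.
from collections import OrderedDict
from typing import Dict, Type

from google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports import (
    SpecialistPoolServiceTransport,
    SpecialistPoolServiceGrpcTransport,
    SpecialistPoolServiceGrpcAsyncIOTransport,
)

_registry = OrderedDict()  # type: Dict[str, Type[SpecialistPoolServiceTransport]]
_registry['grpc'] = SpecialistPoolServiceGrpcTransport
_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport

def get_transport_class(label: str = 'grpc') -> Type[SpecialistPoolServiceTransport]:
    # Same dictionary lookup the generated client performs internally.
    return _registry[label]
# ---------------------------------------------------------------------------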
- if ":" not in host: - host += ":443" + if ':' not in host: + host += ':443' self._host = host # If no credentials are provided, then determine the appropriate # defaults. - if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, + scopes=scopes, + quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_specialist_pool: gapic_v1.method.wrap_method( + self.create_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.get_specialist_pool: gapic_v1.method.wrap_method( + self.get_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.list_specialist_pools: gapic_v1.method.wrap_method( + self.list_specialist_pools, + default_timeout=None, + client_info=client_info, + ), + self.delete_specialist_pool: gapic_v1.method.wrap_method( + self.delete_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.update_specialist_pool: gapic_v1.method.wrap_method( + self.update_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property - def create_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation - ]: - raise NotImplementedError + def create_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def get_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool, - ]: - raise NotImplementedError + def get_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool] + ]]: + raise NotImplementedError() @property - def list_specialist_pools( - self, - ) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, - ]: - raise NotImplementedError + def list_specialist_pools(self) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ]]: + raise NotImplementedError() @property - def delete_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation - ]: - raise NotImplementedError + def delete_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + 
typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() @property - def update_specialist_pool( - self, - ) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation - ]: - raise NotImplementedError - - -__all__ = ("SpecialistPoolServiceTransport",) + def update_specialist_pool(self) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[ + operations.Operation, + typing.Awaitable[operations.Operation] + ]]: + raise NotImplementedError() + + +__all__ = ( + 'SpecialistPoolServiceTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 92cff5699c..071a58862f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -27,7 +31,7 @@ from google.cloud.aiplatform_v1beta1.types import specialist_pool_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import SpecialistPoolServiceTransport +from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): @@ -47,14 +51,20 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - channel: grpc.Channel = None - ) -> None: + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -65,29 +75,107 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. 
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False
 
-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+        elif api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+
+            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]
 
-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
 
     @classmethod
-    def create_channel(
-        cls,
-        host: str = "aiplatform.googleapis.com",
-        credentials: credentials.Credentials = None,
-        **kwargs
-    ) -> grpc.Channel:
+    def create_channel(cls,
+                       host: str = 'aiplatform.googleapis.com',
+                       credentials: credentials.Credentials = None,
+                       credentials_file: str = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
             address (Optional[str]): The host for the channel to use.
@@ -96,30 +184,37 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
+        scopes = scopes or cls.AUTH_SCOPES
         return grpc_helpers.create_channel(
-            host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs
         )
 
     @property
     def grpc_channel(self) -> grpc.Channel:
-        """Create the channel designed to connect to this service.
-
-        This property caches on the instance; repeated calls return
-        the same channel.
+        """Return the channel designed to connect to this service.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
-        if not hasattr(self, "_grpc_channel"):
-            self._grpc_channel = self.create_channel(
-                self._host, credentials=self._credentials,
-            )
-
-        # Return the channel from cache.
         return self._grpc_channel
 
     @property
@@ -130,20 +225,18 @@ def operations_client(self) -> operations_v1.OperationsClient:
         client.
         """
         # Sanity check: Only create a new client if we do not already have one.
-        if "operations_client" not in self.__dict__:
-            self.__dict__["operations_client"] = operations_v1.OperationsClient(
+        if 'operations_client' not in self.__dict__:
+            self.__dict__['operations_client'] = operations_v1.OperationsClient(
                 self.grpc_channel
             )
 
         # Return the client from cache.
-        return self.__dict__["operations_client"]
+        return self.__dict__['operations_client']
 
     @property
-    def create_specialist_pool(
-        self,
-    ) -> Callable[
-        [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation
-    ]:
+    def create_specialist_pool(self) -> Callable[
+            [specialist_pool_service.CreateSpecialistPoolRequest],
+            operations.Operation]:
         r"""Return a callable for the create specialist pool method over gRPC.
 
         Creates a SpecialistPool.
@@ -158,21 +251,18 @@ def create_specialist_pool(
         # the request.
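# ---------------------------------------------------------------------------
# [Editor's example -- illustrative sketch, not part of the generated patch.]
# The deprecated api_mtls_endpoint/client_cert_source path above expects a
# callback returning PEM bytes, which feed grpc.ssl_channel_credentials().
# The endpoint and certificate paths below are hypothetical.
from typing import Tuple

from google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports import (
    SpecialistPoolServiceGrpcTransport,
)

def my_cert_source() -> Tuple[bytes, bytes]:
    # Return (certificate_chain, private_key), both PEM-encoded.
    with open("client_cert.pem", "rb") as cert, open("client_key.pem", "rb") as key:
        return cert.read(), key.read()

transport = SpecialistPoolServiceGrpcTransport(
    api_mtls_endpoint="aiplatform.mtls.googleapis.com",
    client_cert_source=my_cert_source,  # emits the DeprecationWarning above
)
# ---------------------------------------------------------------------------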
# gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "create_specialist_pool" not in self._stubs: - self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", + if 'create_specialist_pool' not in self._stubs: + self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["create_specialist_pool"] + return self._stubs['create_specialist_pool'] @property - def get_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool, - ]: + def get_specialist_pool(self) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -187,21 +277,18 @@ def get_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "get_specialist_pool" not in self._stubs: - self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", + if 'get_specialist_pool' not in self._stubs: + self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs["get_specialist_pool"] + return self._stubs['get_specialist_pool'] @property - def list_specialist_pools( - self, - ) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, - ]: + def list_specialist_pools(self) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -216,20 +303,18 @@ def list_specialist_pools( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "list_specialist_pools" not in self._stubs: - self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", + if 'list_specialist_pools' not in self._stubs: + self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs["list_specialist_pools"] + return self._stubs['list_specialist_pools'] @property - def delete_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation - ]: + def delete_specialist_pool(self) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the delete specialist pool method over gRPC. 
Deletes a SpecialistPool as well as all Specialists @@ -245,20 +330,18 @@ def delete_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "delete_specialist_pool" not in self._stubs: - self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", + if 'delete_specialist_pool' not in self._stubs: + self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["delete_specialist_pool"] + return self._stubs['delete_specialist_pool'] @property - def update_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation - ]: + def update_specialist_pool(self) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + operations.Operation]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -273,13 +356,15 @@ def update_specialist_pool( # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if "update_specialist_pool" not in self._stubs: - self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", + if 'update_specialist_pool' not in self._stubs: + self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs["update_specialist_pool"] + return self._stubs['update_specialist_pool'] -__all__ = ("SpecialistPoolServiceGrpcTransport",) +__all__ = ( + 'SpecialistPoolServiceGrpcTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..68639540e7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -0,0 +1,375 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import SpecialistPoolServiceGrpcTransport + + +class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): + """gRPC AsyncIO backend transport for SpecialistPoolService. + + A service for creating and managing Customer SpecialistPools. + When customers start Data Labeling jobs, they can reuse/create + Specialist Pools to bring their own Specialists to label the + data. Customers can add/remove Managers for the Specialist Pool + on Cloud console, then Managers will get email notifications to + manage Specialists and tasks on CrowdCompute console. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. 
+ """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs + ) + + def __init__(self, *, + host: str = 'aiplatform.googleapis.com', + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + elif api_mtls_endpoint: + warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + + host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. 
+ if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if 'operations_client' not in self.__dict__: + self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__['operations_client'] + + @property + def create_specialist_pool(self) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the create specialist pool method over gRPC. + + Creates a SpecialistPool. + + Returns: + Callable[[~.CreateSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_specialist_pool' not in self._stubs: + self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', + request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['create_specialist_pool'] + + @property + def get_specialist_pool(self) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool]]: + r"""Return a callable for the get specialist pool method over gRPC. + + Gets a SpecialistPool. + + Returns: + Callable[[~.GetSpecialistPoolRequest], + Awaitable[~.SpecialistPool]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_specialist_pool' not in self._stubs: + self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', + request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, + response_deserializer=specialist_pool.SpecialistPool.deserialize, + ) + return self._stubs['get_specialist_pool'] + + @property + def list_specialist_pools(self) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]]: + r"""Return a callable for the list specialist pools method over gRPC. + + Lists SpecialistPools in a Location. + + Returns: + Callable[[~.ListSpecialistPoolsRequest], + Awaitable[~.ListSpecialistPoolsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_specialist_pools' not in self._stubs: + self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', + request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, + response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, + ) + return self._stubs['list_specialist_pools'] + + @property + def delete_specialist_pool(self) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the delete specialist pool method over gRPC. + + Deletes a SpecialistPool as well as all Specialists + in the pool. + + Returns: + Callable[[~.DeleteSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_specialist_pool' not in self._stubs: + self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', + request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['delete_specialist_pool'] + + @property + def update_specialist_pool(self) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation]]: + r"""Return a callable for the update specialist pool method over gRPC. + + Updates a SpecialistPool. + + Returns: + Callable[[~.UpdateSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'update_specialist_pool' not in self._stubs: + self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( + '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', + request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs['update_specialist_pool'] + + +__all__ = ( + 'SpecialistPoolServiceGrpcAsyncIOTransport', +) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 93508415dc..08cb2d804e 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,343 +15,199 @@ # limitations under the License. # -from .annotation_spec import AnnotationSpec -from .io import ( - GcsSource, - GcsDestination, - BigQuerySource, - BigQueryDestination, - ContainerRegistryDestination, -) -from .dataset import ( - Dataset, - ImportDataConfig, - ExportDataConfig, -) -from .manual_batch_tuning_parameters import ManualBatchTuningParameters -from .completion_stats import CompletionStats -from .model_evaluation_slice import ModelEvaluationSlice -from .machine_resources import ( - MachineSpec, - DedicatedResources, - AutomaticResources, - BatchDedicatedResources, - ResourcesConsumed, -) -from .deployed_model_ref import DeployedModelRef -from .env_var import EnvVar -from .explanation_metadata import ExplanationMetadata -from .explanation import ( - Explanation, - ModelExplanation, - Attribution, - ExplanationSpec, - ExplanationParameters, - SampledShapleyAttribution, -) -from .model import ( - Model, - PredictSchemata, - ModelContainerSpec, - Port, -) -from .training_pipeline import ( - TrainingPipeline, - InputDataConfig, - FractionSplit, - FilterSplit, - PredefinedSplit, - TimestampSplit, -) -from .model_evaluation import ModelEvaluation -from .batch_prediction_job import BatchPredictionJob -from .custom_job import ( - CustomJob, - CustomJobSpec, - WorkerPoolSpec, - ContainerSpec, - PythonPackageSpec, - Scheduling, -) -from .specialist_pool import SpecialistPool -from .data_labeling_job import ( - DataLabelingJob, - ActiveLearningConfig, - SampleConfig, - TrainingConfig, -) -from .study import ( - Trial, - StudySpec, - Measurement, -) -from .hyperparameter_tuning_job import HyperparameterTuningJob -from .job_service import ( - CreateCustomJobRequest, - GetCustomJobRequest, - ListCustomJobsRequest, - ListCustomJobsResponse, - DeleteCustomJobRequest, - CancelCustomJobRequest, - CreateDataLabelingJobRequest, - GetDataLabelingJobRequest, - ListDataLabelingJobsRequest, - ListDataLabelingJobsResponse, - DeleteDataLabelingJobRequest, - CancelDataLabelingJobRequest, - CreateHyperparameterTuningJobRequest, - GetHyperparameterTuningJobRequest, - ListHyperparameterTuningJobsRequest, - ListHyperparameterTuningJobsResponse, - DeleteHyperparameterTuningJobRequest, - CancelHyperparameterTuningJobRequest, - CreateBatchPredictionJobRequest, - GetBatchPredictionJobRequest, - ListBatchPredictionJobsRequest, - ListBatchPredictionJobsResponse, - DeleteBatchPredictionJobRequest, - CancelBatchPredictionJobRequest, -) -from .user_action_reference import UserActionReference -from .annotation import Annotation -from .operation import ( - GenericOperationMetadata, - DeleteOperationMetadata, -) -from .endpoint import ( - Endpoint, - DeployedModel, -) -from .prediction_service import ( - PredictRequest, - PredictResponse, - ExplainRequest, - ExplainResponse, -) -from 
.endpoint_service import ( - CreateEndpointRequest, - CreateEndpointOperationMetadata, - GetEndpointRequest, - ListEndpointsRequest, - ListEndpointsResponse, - UpdateEndpointRequest, - DeleteEndpointRequest, - DeployModelRequest, - DeployModelResponse, - DeployModelOperationMetadata, - UndeployModelRequest, - UndeployModelResponse, - UndeployModelOperationMetadata, -) -from .pipeline_service import ( - CreateTrainingPipelineRequest, - GetTrainingPipelineRequest, - ListTrainingPipelinesRequest, - ListTrainingPipelinesResponse, - DeleteTrainingPipelineRequest, - CancelTrainingPipelineRequest, -) -from .specialist_pool_service import ( - CreateSpecialistPoolRequest, - CreateSpecialistPoolOperationMetadata, - GetSpecialistPoolRequest, - ListSpecialistPoolsRequest, - ListSpecialistPoolsResponse, - DeleteSpecialistPoolRequest, - UpdateSpecialistPoolRequest, - UpdateSpecialistPoolOperationMetadata, -) -from .model_service import ( - UploadModelRequest, - UploadModelOperationMetadata, - UploadModelResponse, - GetModelRequest, - ListModelsRequest, - ListModelsResponse, - UpdateModelRequest, - DeleteModelRequest, - ExportModelRequest, - ExportModelOperationMetadata, - ExportModelResponse, - GetModelEvaluationRequest, - ListModelEvaluationsRequest, - ListModelEvaluationsResponse, - GetModelEvaluationSliceRequest, - ListModelEvaluationSlicesRequest, - ListModelEvaluationSlicesResponse, -) -from .data_item import DataItem -from .dataset_service import ( - CreateDatasetRequest, - CreateDatasetOperationMetadata, - GetDatasetRequest, - UpdateDatasetRequest, - ListDatasetsRequest, - ListDatasetsResponse, - DeleteDatasetRequest, - ImportDataRequest, - ImportDataResponse, - ImportDataOperationMetadata, - ExportDataRequest, - ExportDataResponse, - ExportDataOperationMetadata, - ListDataItemsRequest, - ListDataItemsResponse, - GetAnnotationSpecRequest, - ListAnnotationsRequest, - ListAnnotationsResponse, -) +from .annotation_spec import (AnnotationSpec, ) +from .io import (GcsSource, GcsDestination, BigQuerySource, BigQueryDestination, ContainerRegistryDestination, ) +from .dataset import (Dataset, ImportDataConfig, ExportDataConfig, ) +from .manual_batch_tuning_parameters import (ManualBatchTuningParameters, ) +from .completion_stats import (CompletionStats, ) +from .model_evaluation_slice import (ModelEvaluationSlice, ) +from .machine_resources import (MachineSpec, DedicatedResources, AutomaticResources, BatchDedicatedResources, ResourcesConsumed, ) +from .deployed_model_ref import (DeployedModelRef, ) +from .env_var import (EnvVar, ) +from .explanation_metadata import (ExplanationMetadata, ) +from .explanation import (Explanation, ModelExplanation, Attribution, ExplanationSpec, ExplanationParameters, SampledShapleyAttribution, ) +from .model import (Model, PredictSchemata, ModelContainerSpec, Port, ) +from .training_pipeline import (TrainingPipeline, InputDataConfig, FractionSplit, FilterSplit, PredefinedSplit, TimestampSplit, ) +from .model_evaluation import (ModelEvaluation, ) +from .migratable_resource import (MigratableResource, ) +from .operation import (GenericOperationMetadata, DeleteOperationMetadata, ) +from .migration_service import (SearchMigratableResourcesRequest, SearchMigratableResourcesResponse, BatchMigrateResourcesRequest, MigrateResourceRequest, BatchMigrateResourcesResponse, MigrateResourceResponse, BatchMigrateResourcesOperationMetadata, ) +from .batch_prediction_job import (BatchPredictionJob, ) +from .custom_job import (CustomJob, CustomJobSpec, WorkerPoolSpec, ContainerSpec, 
PythonPackageSpec, Scheduling, ) +from .specialist_pool import (SpecialistPool, ) +from .data_labeling_job import (DataLabelingJob, ActiveLearningConfig, SampleConfig, TrainingConfig, ) +from .study import (Trial, StudySpec, Measurement, ) +from .hyperparameter_tuning_job import (HyperparameterTuningJob, ) +from .job_service import (CreateCustomJobRequest, GetCustomJobRequest, ListCustomJobsRequest, ListCustomJobsResponse, DeleteCustomJobRequest, CancelCustomJobRequest, CreateDataLabelingJobRequest, GetDataLabelingJobRequest, ListDataLabelingJobsRequest, ListDataLabelingJobsResponse, DeleteDataLabelingJobRequest, CancelDataLabelingJobRequest, CreateHyperparameterTuningJobRequest, GetHyperparameterTuningJobRequest, ListHyperparameterTuningJobsRequest, ListHyperparameterTuningJobsResponse, DeleteHyperparameterTuningJobRequest, CancelHyperparameterTuningJobRequest, CreateBatchPredictionJobRequest, GetBatchPredictionJobRequest, ListBatchPredictionJobsRequest, ListBatchPredictionJobsResponse, DeleteBatchPredictionJobRequest, CancelBatchPredictionJobRequest, ) +from .user_action_reference import (UserActionReference, ) +from .annotation import (Annotation, ) +from .endpoint import (Endpoint, DeployedModel, ) +from .prediction_service import (PredictRequest, PredictResponse, ExplainRequest, ExplainResponse, ) +from .endpoint_service import (CreateEndpointRequest, CreateEndpointOperationMetadata, GetEndpointRequest, ListEndpointsRequest, ListEndpointsResponse, UpdateEndpointRequest, DeleteEndpointRequest, DeployModelRequest, DeployModelResponse, DeployModelOperationMetadata, UndeployModelRequest, UndeployModelResponse, UndeployModelOperationMetadata, ) +from .pipeline_service import (CreateTrainingPipelineRequest, GetTrainingPipelineRequest, ListTrainingPipelinesRequest, ListTrainingPipelinesResponse, DeleteTrainingPipelineRequest, CancelTrainingPipelineRequest, ) +from .specialist_pool_service import (CreateSpecialistPoolRequest, CreateSpecialistPoolOperationMetadata, GetSpecialistPoolRequest, ListSpecialistPoolsRequest, ListSpecialistPoolsResponse, DeleteSpecialistPoolRequest, UpdateSpecialistPoolRequest, UpdateSpecialistPoolOperationMetadata, ) +from .model_service import (UploadModelRequest, UploadModelOperationMetadata, UploadModelResponse, GetModelRequest, ListModelsRequest, ListModelsResponse, UpdateModelRequest, DeleteModelRequest, ExportModelRequest, ExportModelOperationMetadata, ExportModelResponse, GetModelEvaluationRequest, ListModelEvaluationsRequest, ListModelEvaluationsResponse, GetModelEvaluationSliceRequest, ListModelEvaluationSlicesRequest, ListModelEvaluationSlicesResponse, ) +from .data_item import (DataItem, ) +from .dataset_service import (CreateDatasetRequest, CreateDatasetOperationMetadata, GetDatasetRequest, UpdateDatasetRequest, ListDatasetsRequest, ListDatasetsResponse, DeleteDatasetRequest, ImportDataRequest, ImportDataResponse, ImportDataOperationMetadata, ExportDataRequest, ExportDataResponse, ExportDataOperationMetadata, ListDataItemsRequest, ListDataItemsResponse, GetAnnotationSpecRequest, ListAnnotationsRequest, ListAnnotationsResponse, ) __all__ = ( - "AnnotationSpec", - "GcsSource", - "GcsDestination", - "BigQuerySource", - "BigQueryDestination", - "ContainerRegistryDestination", - "Dataset", - "ImportDataConfig", - "ExportDataConfig", - "ManualBatchTuningParameters", - "CompletionStats", - "ModelEvaluationSlice", - "MachineSpec", - "DedicatedResources", - "AutomaticResources", - "BatchDedicatedResources", - "ResourcesConsumed", - "DeployedModelRef", - "EnvVar", 
- "ExplanationMetadata", - "Explanation", - "ModelExplanation", - "Attribution", - "ExplanationSpec", - "ExplanationParameters", - "SampledShapleyAttribution", - "Model", - "PredictSchemata", - "ModelContainerSpec", - "Port", - "TrainingPipeline", - "InputDataConfig", - "FractionSplit", - "FilterSplit", - "PredefinedSplit", - "TimestampSplit", - "ModelEvaluation", - "BatchPredictionJob", - "CustomJob", - "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", - "PythonPackageSpec", - "Scheduling", - "SpecialistPool", - "DataLabelingJob", - "ActiveLearningConfig", - "SampleConfig", - "TrainingConfig", - "Trial", - "StudySpec", - "Measurement", - "HyperparameterTuningJob", - "CreateCustomJobRequest", - "GetCustomJobRequest", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "DeleteCustomJobRequest", - "CancelCustomJobRequest", - "CreateDataLabelingJobRequest", - "GetDataLabelingJobRequest", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "DeleteDataLabelingJobRequest", - "CancelDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "GetHyperparameterTuningJobRequest", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "DeleteHyperparameterTuningJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "GetBatchPredictionJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "DeleteBatchPredictionJobRequest", - "CancelBatchPredictionJobRequest", - "UserActionReference", - "Annotation", - "GenericOperationMetadata", - "DeleteOperationMetadata", - "Endpoint", - "DeployedModel", - "PredictRequest", - "PredictResponse", - "ExplainRequest", - "ExplainResponse", - "CreateEndpointRequest", - "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelRequest", - "DeployModelResponse", - "DeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UndeployModelOperationMetadata", - "CreateTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", - "CreateSpecialistPoolRequest", - "CreateSpecialistPoolOperationMetadata", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", - "UpdateSpecialistPoolOperationMetadata", - "UploadModelRequest", - "UploadModelOperationMetadata", - "UploadModelResponse", - "GetModelRequest", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", - "DeleteModelRequest", - "ExportModelRequest", - "ExportModelOperationMetadata", - "ExportModelResponse", - "GetModelEvaluationRequest", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "GetModelEvaluationSliceRequest", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", - "DataItem", - "CreateDatasetRequest", - "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", - "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", - "GetAnnotationSpecRequest", - "ListAnnotationsRequest", - 
"ListAnnotationsResponse", + 'AnnotationSpec', + 'GcsSource', + 'GcsDestination', + 'BigQuerySource', + 'BigQueryDestination', + 'ContainerRegistryDestination', + 'Dataset', + 'ImportDataConfig', + 'ExportDataConfig', + 'ManualBatchTuningParameters', + 'CompletionStats', + 'ModelEvaluationSlice', + 'MachineSpec', + 'DedicatedResources', + 'AutomaticResources', + 'BatchDedicatedResources', + 'ResourcesConsumed', + 'DeployedModelRef', + 'EnvVar', + 'ExplanationMetadata', + 'Explanation', + 'ModelExplanation', + 'Attribution', + 'ExplanationSpec', + 'ExplanationParameters', + 'SampledShapleyAttribution', + 'Model', + 'PredictSchemata', + 'ModelContainerSpec', + 'Port', + 'TrainingPipeline', + 'InputDataConfig', + 'FractionSplit', + 'FilterSplit', + 'PredefinedSplit', + 'TimestampSplit', + 'ModelEvaluation', + 'MigratableResource', + 'GenericOperationMetadata', + 'DeleteOperationMetadata', + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'BatchMigrateResourcesRequest', + 'MigrateResourceRequest', + 'BatchMigrateResourcesResponse', + 'MigrateResourceResponse', + 'BatchMigrateResourcesOperationMetadata', + 'BatchPredictionJob', + 'CustomJob', + 'CustomJobSpec', + 'WorkerPoolSpec', + 'ContainerSpec', + 'PythonPackageSpec', + 'Scheduling', + 'SpecialistPool', + 'DataLabelingJob', + 'ActiveLearningConfig', + 'SampleConfig', + 'TrainingConfig', + 'Trial', + 'StudySpec', + 'Measurement', + 'HyperparameterTuningJob', + 'CreateCustomJobRequest', + 'GetCustomJobRequest', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'DeleteCustomJobRequest', + 'CancelCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'GetDataLabelingJobRequest', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'DeleteDataLabelingJobRequest', + 'CancelDataLabelingJobRequest', + 'CreateHyperparameterTuningJobRequest', + 'GetHyperparameterTuningJobRequest', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'DeleteHyperparameterTuningJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CreateBatchPredictionJobRequest', + 'GetBatchPredictionJobRequest', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'DeleteBatchPredictionJobRequest', + 'CancelBatchPredictionJobRequest', + 'UserActionReference', + 'Annotation', + 'Endpoint', + 'DeployedModel', + 'PredictRequest', + 'PredictResponse', + 'ExplainRequest', + 'ExplainResponse', + 'CreateEndpointRequest', + 'CreateEndpointOperationMetadata', + 'GetEndpointRequest', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'UpdateEndpointRequest', + 'DeleteEndpointRequest', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UndeployModelOperationMetadata', + 'CreateTrainingPipelineRequest', + 'GetTrainingPipelineRequest', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'DeleteTrainingPipelineRequest', + 'CancelTrainingPipelineRequest', + 'CreateSpecialistPoolRequest', + 'CreateSpecialistPoolOperationMetadata', + 'GetSpecialistPoolRequest', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'DeleteSpecialistPoolRequest', + 'UpdateSpecialistPoolRequest', + 'UpdateSpecialistPoolOperationMetadata', + 'UploadModelRequest', + 'UploadModelOperationMetadata', + 'UploadModelResponse', + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'UpdateModelRequest', + 'DeleteModelRequest', + 'ExportModelRequest', + 'ExportModelOperationMetadata', 
+ 'ExportModelResponse', + 'GetModelEvaluationRequest', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'GetModelEvaluationSliceRequest', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', + 'DataItem', + 'CreateDatasetRequest', + 'CreateDatasetOperationMetadata', + 'GetDatasetRequest', + 'UpdateDatasetRequest', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'DeleteDatasetRequest', + 'ImportDataRequest', + 'ImportDataResponse', + 'ImportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportDataOperationMetadata', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'GetAnnotationSpecRequest', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', ) diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index 337b0eeaf5..e82a142396 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"AcceleratorType",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'AcceleratorType', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index 34f3edfa5e..f3f36fb568 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -24,7 +24,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"Annotation",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Annotation', + }, ) @@ -88,14 +91,27 @@ class Annotation(proto.Message): """ name = proto.Field(proto.STRING, number=1) + payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + + payload = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) + + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) + etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field( - proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, + + annotation_source = proto.Field(proto.MESSAGE, number=5, + message=user_action_reference.UserActionReference, ) + labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index 4719fb12ce..068aca741b 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"AnnotationSpec",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'AnnotationSpec', + }, ) @@ -52,9 +55,17 @@ class AnnotationSpec(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + create_time = 
proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 332faaa6a9..55c81889e7 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -18,22 +18,21 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - completion_stats as gca_completion_stats, -) +from google.cloud.aiplatform_v1beta1.types import completion_stats as gca_completion_stats from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import ( - manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, -) +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"BatchPredictionJob",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'BatchPredictionJob', + }, ) @@ -101,27 +100,10 @@ class BatchPredictionJob(proto.Message): This can only be set to true for AutoML tabular Models, and only when the output destination is BigQuery. When it's true, the batch prediction output will include a column - named ``feature_attributions``. - - For AutoML tabular Models, the value of the - ``feature_attributions`` column is a struct that maps from - string to number. The keys in the map are the names of the - features. The values in the map are the how much the - features contribute to the predicted result. Features are - defined as follows: - - - A scalar column defines a feature of the same name as the - column. - - - A struct column defines multiple features, one feature - per leaf field. The feature name is the fully qualified - path for the leaf field, separated by ".". For example a - column ``key1`` in the format of {"value1": {"prop1": - number}, "value2": number} defines two features: - ``key1.value1.prop1`` and ``key1.value2`` - - Attributions of each feature is represented as an extra - column in the batch prediction output BigQuery table. + named ``explanation``. The value is a struct that conforms + to the + ``Explanation`` + object. output_info (~.batch_prediction_job.BatchPredictionJob.OutputInfo): Output only. Information further describing the output of this job. @@ -171,7 +153,6 @@ class BatchPredictionJob(proto.Message): See https://goo.gl/xmQnxf for more information and examples of labels. """ - class InputConfig(proto.Message): r"""Configures the input to ``BatchPredictionJob``. @@ -198,10 +179,14 @@ class InputConfig(proto.Message): ``supported_input_storage_formats``. 
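
        Example (an illustrative sketch; the bucket URI and ``jsonl``
        format are assumed values). With the ``source`` oneof introduced
        here, exactly one of ``gcs_source`` or ``bigquery_source`` may be
        set::

            from google.cloud.aiplatform_v1beta1.types import batch_prediction_job, io

            config = batch_prediction_job.BatchPredictionJob.InputConfig(
                instances_format="jsonl",
                gcs_source=io.GcsSource(uris=["gs://example-bucket/instances/*.jsonl"]),
            )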
""" - gcs_source = proto.Field(proto.MESSAGE, number=2, message=io.GcsSource,) - bigquery_source = proto.Field( - proto.MESSAGE, number=3, message=io.BigQuerySource, + gcs_source = proto.Field(proto.MESSAGE, number=2, oneof='source', + message=io.GcsSource, + ) + + bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', + message=io.BigQuerySource, ) + instances_format = proto.Field(proto.STRING, number=1) class OutputConfig(proto.Message): @@ -270,12 +255,14 @@ class OutputConfig(proto.Message): ``supported_output_storage_formats``. """ - gcs_destination = proto.Field( - proto.MESSAGE, number=2, message=io.GcsDestination, + gcs_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', + message=io.GcsDestination, ) - bigquery_destination = proto.Field( - proto.MESSAGE, number=3, message=io.BigQueryDestination, + + bigquery_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', + message=io.BigQueryDestination, ) + predictions_format = proto.Field(proto.STRING, number=1) class OutputInfo(proto.Message): @@ -293,40 +280,78 @@ class OutputInfo(proto.Message): prediction output is written. """ - gcs_output_directory = proto.Field(proto.STRING, number=1) - bigquery_output_dataset = proto.Field(proto.STRING, number=2) + gcs_output_directory = proto.Field(proto.STRING, number=1, oneof='output_location') + + bigquery_output_dataset = proto.Field(proto.STRING, number=2, oneof='output_location') name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) - model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) - dedicated_resources = proto.Field( - proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, + + input_config = proto.Field(proto.MESSAGE, number=4, + message=InputConfig, ) - manual_batch_tuning_parameters = proto.Field( - proto.MESSAGE, - number=8, + + model_parameters = proto.Field(proto.MESSAGE, number=5, + message=struct.Value, + ) + + output_config = proto.Field(proto.MESSAGE, number=6, + message=OutputConfig, + ) + + dedicated_resources = proto.Field(proto.MESSAGE, number=7, + message=machine_resources.BatchDedicatedResources, + ) + + manual_batch_tuning_parameters = proto.Field(proto.MESSAGE, number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) + generate_explanation = proto.Field(proto.BOOL, number=23) - output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) - partial_failures = proto.RepeatedField( - proto.MESSAGE, number=12, message=status.Status, + + output_info = proto.Field(proto.MESSAGE, number=9, + message=OutputInfo, + ) + + state = proto.Field(proto.ENUM, number=10, + enum=job_state.JobState, + ) + + error = proto.Field(proto.MESSAGE, number=11, + message=status.Status, ) - resources_consumed = proto.Field( - proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, + + partial_failures = proto.RepeatedField(proto.MESSAGE, number=12, + message=status.Status, ) - completion_stats = proto.Field( - proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, + + resources_consumed = proto.Field(proto.MESSAGE, number=13, + 
message=machine_resources.ResourcesConsumed, + ) + + completion_stats = proto.Field(proto.MESSAGE, number=14, + message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) + + create_time = proto.Field(proto.MESSAGE, number=15, + message=timestamp.Timestamp, + ) + + start_time = proto.Field(proto.MESSAGE, number=16, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=17, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=18, + message=timestamp.Timestamp, + ) + labels = proto.MapField(proto.STRING, proto.STRING, number=19) diff --git a/google/cloud/aiplatform_v1beta1/types/completion_stats.py b/google/cloud/aiplatform_v1beta1/types/completion_stats.py index 22f1fa7975..3874f412df 100644 --- a/google/cloud/aiplatform_v1beta1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/completion_stats.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"CompletionStats",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'CompletionStats', + }, ) @@ -46,7 +49,9 @@ class CompletionStats(proto.Message): """ successful_count = proto.Field(proto.INT64, number=1) + failed_count = proto.Field(proto.INT64, number=2) + incomplete_count = proto.Field(proto.INT64, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 3d466ab72f..7ab803bec1 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -27,14 +27,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CustomJob", - "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", - "PythonPackageSpec", - "Scheduling", + 'CustomJob', + 'CustomJobSpec', + 'WorkerPoolSpec', + 'ContainerSpec', + 'PythonPackageSpec', + 'Scheduling', }, ) @@ -86,14 +86,37 @@ class CustomJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) - state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + + job_spec = proto.Field(proto.MESSAGE, number=4, + message='CustomJobSpec', + ) + + state = proto.Field(proto.ENUM, number=5, + enum=job_state.JobState, + ) + + create_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) + + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) + + error = proto.Field(proto.MESSAGE, 
number=10, + message=status.Status, + ) + labels = proto.MapField(proto.STRING, proto.STRING, number=11) @@ -139,12 +162,16 @@ class CustomJobSpec(proto.Message): ``//logs/`` """ - worker_pool_specs = proto.RepeatedField( - proto.MESSAGE, number=1, message="WorkerPoolSpec", + worker_pool_specs = proto.RepeatedField(proto.MESSAGE, number=1, + message='WorkerPoolSpec', + ) + + scheduling = proto.Field(proto.MESSAGE, number=3, + message='Scheduling', ) - scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) - base_output_directory = proto.Field( - proto.MESSAGE, number=6, message=io.GcsDestination, + + base_output_directory = proto.Field(proto.MESSAGE, number=6, + message=io.GcsDestination, ) @@ -164,13 +191,18 @@ class WorkerPoolSpec(proto.Message): use for this worker pool. """ - container_spec = proto.Field(proto.MESSAGE, number=6, message="ContainerSpec",) - python_package_spec = proto.Field( - proto.MESSAGE, number=7, message="PythonPackageSpec", + container_spec = proto.Field(proto.MESSAGE, number=6, oneof='task', + message='ContainerSpec', ) - machine_spec = proto.Field( - proto.MESSAGE, number=1, message=machine_resources.MachineSpec, + + python_package_spec = proto.Field(proto.MESSAGE, number=7, oneof='task', + message='PythonPackageSpec', + ) + + machine_spec = proto.Field(proto.MESSAGE, number=1, + message=machine_resources.MachineSpec, ) + replica_count = proto.Field(proto.INT64, number=2) @@ -192,7 +224,9 @@ class ContainerSpec(proto.Message): """ image_uri = proto.Field(proto.STRING, number=1) + command = proto.RepeatedField(proto.STRING, number=2) + args = proto.RepeatedField(proto.STRING, number=3) @@ -221,8 +255,11 @@ class PythonPackageSpec(proto.Message): """ executor_image_uri = proto.Field(proto.STRING, number=1) + package_uris = proto.RepeatedField(proto.STRING, number=2) + python_module = proto.Field(proto.STRING, number=3) + args = proto.RepeatedField(proto.STRING, number=4) @@ -241,7 +278,10 @@ class Scheduling(proto.Message): to workers leaving and joining a job. 
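
        Example (an illustrative sketch; the one-hour timeout is an
        assumed value)::

            from google.protobuf import duration_pb2
            from google.cloud.aiplatform_v1beta1.types import custom_job

            scheduling = custom_job.Scheduling(
                timeout=duration_pb2.Duration(seconds=3600),
                restart_job_on_worker_restart=True,
            )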
""" - timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + timeout = proto.Field(proto.MESSAGE, number=1, + message=duration.Duration, + ) + restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index 418e8cc739..961a153172 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -23,7 +23,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"DataItem",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'DataItem', + }, ) @@ -69,10 +72,21 @@ class DataItem(proto.Message): """ name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + create_time = proto.Field(proto.MESSAGE, number=2, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) + labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + + payload = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) + etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index 19da27e6a9..2d10060738 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -25,12 +25,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "DataLabelingJob", - "ActiveLearningConfig", - "SampleConfig", - "TrainingConfig", + 'DataLabelingJob', + 'ActiveLearningConfig', + 'SampleConfig', + 'TrainingConfig', }, ) @@ -128,22 +128,47 @@ class DataLabelingJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + datasets = proto.RepeatedField(proto.STRING, number=3) + annotation_labels = proto.MapField(proto.STRING, proto.STRING, number=12) + labeler_count = proto.Field(proto.INT32, number=4) + instruction_uri = proto.Field(proto.STRING, number=5) + inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) - state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) + + inputs = proto.Field(proto.MESSAGE, number=7, + message=struct.Value, + ) + + state = proto.Field(proto.ENUM, number=8, + enum=job_state.JobState, + ) + labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) - create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) + + current_spend = proto.Field(proto.MESSAGE, number=14, + message=money.Money, + ) + + create_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=10, + message=timestamp.Timestamp, + ) + labels = proto.MapField(proto.STRING, proto.STRING, number=11) + specialist_pools = proto.RepeatedField(proto.STRING, number=16) - active_learning_config = proto.Field( - proto.MESSAGE, number=21, 
message="ActiveLearningConfig", + + active_learning_config = proto.Field(proto.MESSAGE, number=21, + message='ActiveLearningConfig', ) @@ -172,10 +197,17 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field(proto.INT64, number=1) - max_data_item_percentage = proto.Field(proto.INT32, number=2) - sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) - training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + max_data_item_count = proto.Field(proto.INT64, number=1, oneof='human_labeling_budget') + + max_data_item_percentage = proto.Field(proto.INT32, number=2, oneof='human_labeling_budget') + + sample_config = proto.Field(proto.MESSAGE, number=3, + message='SampleConfig', + ) + + training_config = proto.Field(proto.MESSAGE, number=4, + message='TrainingConfig', + ) class SampleConfig(proto.Message): @@ -196,7 +228,6 @@ class SampleConfig(proto.Message): strategy will decide which data should be selected for human labeling in every batch. """ - class SampleStrategy(proto.Enum): r"""Sample strategy decides which subset of DataItems should be selected for human labeling in every batch. @@ -204,9 +235,13 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field(proto.INT32, number=1) - following_batch_sample_percentage = proto.Field(proto.INT32, number=3) - sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + initial_batch_sample_percentage = proto.Field(proto.INT32, number=1, oneof='initial_batch_sample_size') + + following_batch_sample_percentage = proto.Field(proto.INT32, number=3, oneof='following_batch_sample_size') + + sample_strategy = proto.Field(proto.ENUM, number=5, + enum=SampleStrategy, + ) class TrainingConfig(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 3675d8f42a..5138badf1f 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -24,8 +24,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Dataset', + 'ImportDataConfig', + 'ExportDataConfig', + }, ) @@ -83,12 +87,25 @@ class Dataset(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + metadata = proto.Field(proto.MESSAGE, number=8, + message=struct.Value, + ) + + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) + etag = proto.Field(proto.STRING, number=6) + labels = proto.MapField(proto.STRING, proto.STRING, number=7) @@ -124,8 +141,12 @@ class ImportDataConfig(proto.Message): Object `__. 
""" - gcs_source = proto.Field(proto.MESSAGE, number=1, message=io.GcsSource,) + gcs_source = proto.Field(proto.MESSAGE, number=1, oneof='source', + message=io.GcsSource, + ) + data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) + import_schema_uri = proto.Field(proto.STRING, number=4) @@ -153,7 +174,10 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, message=io.GcsDestination,) + gcs_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', + message=io.GcsDestination, + ) + annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 56b51a97cb..594484375c 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -26,26 +26,26 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateDatasetRequest", - "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", - "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", - "GetAnnotationSpecRequest", - "ListAnnotationsRequest", - "ListAnnotationsResponse", + 'CreateDatasetRequest', + 'CreateDatasetOperationMetadata', + 'GetDatasetRequest', + 'UpdateDatasetRequest', + 'ListDatasetsRequest', + 'ListDatasetsResponse', + 'DeleteDatasetRequest', + 'ImportDataRequest', + 'ImportDataResponse', + 'ImportDataOperationMetadata', + 'ExportDataRequest', + 'ExportDataResponse', + 'ExportDataOperationMetadata', + 'ListDataItemsRequest', + 'ListDataItemsResponse', + 'GetAnnotationSpecRequest', + 'ListAnnotationsRequest', + 'ListAnnotationsResponse', }, ) @@ -64,7 +64,10 @@ class CreateDatasetRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) + + dataset = proto.Field(proto.MESSAGE, number=2, + message=gca_dataset.Dataset, + ) class CreateDatasetOperationMetadata(proto.Message): @@ -76,8 +79,8 @@ class CreateDatasetOperationMetadata(proto.Message): The operation generic information. 
""" - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -93,7 +96,10 @@ class GetDatasetRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class UpdateDatasetRequest(proto.Message): @@ -116,8 +122,13 @@ class UpdateDatasetRequest(proto.Message): - ``labels`` """ - dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + dataset = proto.Field(proto.MESSAGE, number=1, + message=gca_dataset.Dataset, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class ListDatasetsRequest(proto.Message): @@ -147,10 +158,17 @@ class ListDatasetsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + order_by = proto.Field(proto.STRING, number=6) @@ -170,9 +188,10 @@ class ListDatasetsResponse(proto.Message): def raw_page(self): return self - datasets = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_dataset.Dataset, + datasets = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_dataset.Dataset, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -205,8 +224,9 @@ class ImportDataRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField( - proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, + + import_configs = proto.RepeatedField(proto.MESSAGE, number=2, + message=gca_dataset.ImportDataConfig, ) @@ -225,8 +245,8 @@ class ImportDataOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -243,8 +263,9 @@ class ExportDataRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - export_config = proto.Field( - proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, + + export_config = proto.Field(proto.MESSAGE, number=2, + message=gca_dataset.ExportDataConfig, ) @@ -274,9 +295,10 @@ class ExportDataOperationMetadata(proto.Message): the directory. 
""" - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) + gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -304,10 +326,17 @@ class ListDataItemsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + order_by = proto.Field(proto.STRING, number=6) @@ -327,9 +356,10 @@ class ListDataItemsResponse(proto.Message): def raw_page(self): return self - data_items = proto.RepeatedField( - proto.MESSAGE, number=1, message=data_item.DataItem, + data_items = proto.RepeatedField(proto.MESSAGE, number=1, + message=data_item.DataItem, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -347,7 +377,10 @@ class GetAnnotationSpecRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class ListAnnotationsRequest(proto.Message): @@ -375,10 +408,17 @@ class ListAnnotationsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + order_by = proto.Field(proto.STRING, number=6) @@ -398,9 +438,10 @@ class ListAnnotationsResponse(proto.Message): def raw_page(self): return self - annotations = proto.RepeatedField( - proto.MESSAGE, number=1, message=annotation.Annotation, + annotations = proto.RepeatedField(proto.MESSAGE, number=1, + message=annotation.Annotation, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py index 6a7f18850f..aa5c8424aa 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"DeployedModelRef",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'DeployedModelRef', + }, ) @@ -35,6 +38,7 @@ class DeployedModelRef(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 315a9de179..7d1275703d 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -24,7 +24,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"Endpoint", "DeployedModel",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Endpoint', + 'DeployedModel', + }, ) @@ -82,16 +86,28 @@ class Endpoint(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = 
proto.Field(proto.STRING, number=2) + description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField( - proto.MESSAGE, number=4, message="DeployedModel", + + deployed_models = proto.RepeatedField(proto.MESSAGE, number=4, + message='DeployedModel', ) + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) + etag = proto.Field(proto.STRING, number=6) + labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + + create_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=9, + message=timestamp.Timestamp, + ) class DeployedModel(proto.Message): @@ -158,20 +174,30 @@ class DeployedModel(proto.Message): option. """ - dedicated_resources = proto.Field( - proto.MESSAGE, number=7, message=machine_resources.DedicatedResources, + dedicated_resources = proto.Field(proto.MESSAGE, number=7, oneof='prediction_resources', + message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field( - proto.MESSAGE, number=8, message=machine_resources.AutomaticResources, + + automatic_resources = proto.Field(proto.MESSAGE, number=8, oneof='prediction_resources', + message=machine_resources.AutomaticResources, ) + id = proto.Field(proto.STRING, number=1) + model = proto.Field(proto.STRING, number=2) + display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - explanation_spec = proto.Field( - proto.MESSAGE, number=9, message=explanation.ExplanationSpec, + + create_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, ) + + explanation_spec = proto.Field(proto.MESSAGE, number=9, + message=explanation.ExplanationSpec, + ) + enable_container_logging = proto.Field(proto.BOOL, number=12) + enable_access_logging = proto.Field(proto.BOOL, number=13) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index 43e8eacdfb..acbf58d123 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateEndpointRequest", - "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelRequest", - "DeployModelResponse", - "DeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UndeployModelOperationMetadata", + 'CreateEndpointRequest', + 'CreateEndpointOperationMetadata', + 'GetEndpointRequest', + 'ListEndpointsRequest', + 'ListEndpointsResponse', + 'UpdateEndpointRequest', + 'DeleteEndpointRequest', + 'DeployModelRequest', + 'DeployModelResponse', + 'DeployModelOperationMetadata', + 'UndeployModelRequest', + 'UndeployModelResponse', + 'UndeployModelOperationMetadata', }, ) @@ -57,7 +57,10 @@ class CreateEndpointRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) + + endpoint = proto.Field(proto.MESSAGE, number=2, + message=gca_endpoint.Endpoint, + ) class 
CreateEndpointOperationMetadata(proto.Message): @@ -69,8 +72,8 @@ class CreateEndpointOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -135,10 +138,16 @@ class ListEndpointsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListEndpointsResponse(proto.Message): @@ -158,9 +167,10 @@ class ListEndpointsResponse(proto.Message): def raw_page(self): return self - endpoints = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, + endpoints = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_endpoint.Endpoint, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -177,8 +187,13 @@ class UpdateEndpointRequest(proto.Message): resource. """ - endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + endpoint = proto.Field(proto.MESSAGE, number=1, + message=gca_endpoint.Endpoint, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class DeleteEndpointRequest(proto.Message): @@ -230,9 +245,11 @@ class DeployModelRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field( - proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, + + deployed_model = proto.Field(proto.MESSAGE, number=2, + message=gca_endpoint.DeployedModel, ) + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -246,8 +263,8 @@ class DeployModelResponse(proto.Message): the Endpoint. """ - deployed_model = proto.Field( - proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel, + deployed_model = proto.Field(proto.MESSAGE, number=1, + message=gca_endpoint.DeployedModel, ) @@ -260,8 +277,8 @@ class DeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -289,7 +306,9 @@ class UndeployModelRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + deployed_model_id = proto.Field(proto.STRING, number=2) + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -308,8 +327,8 @@ class UndeployModelOperationMetadata(proto.Message): The operation generic information. 
""" - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 0c22313d63..3eb6531af1 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"EnvVar",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'EnvVar', + }, ) @@ -43,6 +46,7 @@ class EnvVar(proto.Message): """ name = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 5e20ef2699..41778b055c 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -23,21 +23,22 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "Explanation", - "ModelExplanation", - "Attribution", - "ExplanationSpec", - "ExplanationParameters", - "SampledShapleyAttribution", + 'Explanation', + 'ModelExplanation', + 'Attribution', + 'ExplanationSpec', + 'ExplanationParameters', + 'SampledShapleyAttribution', }, ) class Explanation(proto.Message): - r"""Explanation of a ``prediction`` produced - by the Model on a given + r"""Explanation of a prediction (provided in + ``PredictResponse.predictions`` + ) produced by the Model on a given ``instance``. Currently, only AutoML tabular Models support explanation. @@ -58,7 +59,9 @@ class Explanation(proto.Message): explaining. """ - attributions = proto.RepeatedField(proto.MESSAGE, number=1, message="Attribution",) + attributions = proto.RepeatedField(proto.MESSAGE, number=1, + message='Attribution', + ) class ModelExplanation(proto.Message): @@ -97,8 +100,8 @@ class ModelExplanation(proto.Message): is not populated. """ - mean_attributions = proto.RepeatedField( - proto.MESSAGE, number=1, message="Attribution", + mean_attributions = proto.RepeatedField(proto.MESSAGE, number=1, + message='Attribution', ) @@ -203,10 +206,17 @@ class Attribution(proto.Message): """ baseline_output_value = proto.Field(proto.DOUBLE, number=1) + instance_output_value = proto.Field(proto.DOUBLE, number=2) - feature_attributions = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + + feature_attributions = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) + output_index = proto.RepeatedField(proto.INT32, number=4) + output_display_name = proto.Field(proto.STRING, number=5) + approximation_error = proto.Field(proto.DOUBLE, number=6) @@ -223,9 +233,12 @@ class ExplanationSpec(proto.Message): input and output for explanation. """ - parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) - metadata = proto.Field( - proto.MESSAGE, number=2, message=explanation_metadata.ExplanationMetadata, + parameters = proto.Field(proto.MESSAGE, number=1, + message='ExplanationParameters', + ) + + metadata = proto.Field(proto.MESSAGE, number=2, + message=explanation_metadata.ExplanationMetadata, ) @@ -241,8 +254,8 @@ class ExplanationParameters(proto.Message): considering all subsets of features. 
""" - sampled_shapley_attribution = proto.Field( - proto.MESSAGE, number=1, message="SampledShapleyAttribution", + sampled_shapley_attribution = proto.Field(proto.MESSAGE, number=1, + message='SampledShapleyAttribution', ) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 1b9f005857..a6fed18554 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -22,7 +22,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"ExplanationMetadata",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ExplanationMetadata', + }, ) @@ -58,7 +61,6 @@ class ExplanationMetadata(proto.Message): output URI will point to a location where the user only has a read access. """ - class InputMetadata(proto.Message): r"""Metadata of the input of a feature. @@ -81,8 +83,8 @@ class InputMetadata(proto.Message): ``instance_schema_uri``. """ - input_baselines = proto.RepeatedField( - proto.MESSAGE, number=1, message=struct.Value, + input_baselines = proto.RepeatedField(proto.MESSAGE, number=1, + message=struct.Value, ) class OutputMetadata(proto.Message): @@ -118,17 +120,20 @@ class OutputMetadata(proto.Message): for a specific output. """ - index_display_name_mapping = proto.Field( - proto.MESSAGE, number=1, message=struct.Value, + index_display_name_mapping = proto.Field(proto.MESSAGE, number=1, oneof='display_name_mapping', + message=struct.Value, ) - display_name_mapping_key = proto.Field(proto.STRING, number=2) - inputs = proto.MapField( - proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, + display_name_mapping_key = proto.Field(proto.STRING, number=2, oneof='display_name_mapping') + + inputs = proto.MapField(proto.STRING, proto.MESSAGE, number=1, + message=InputMetadata, ) - outputs = proto.MapField( - proto.STRING, proto.MESSAGE, number=2, message=OutputMetadata, + + outputs = proto.MapField(proto.STRING, proto.MESSAGE, number=2, + message=OutputMetadata, ) + feature_attributions_schema_uri = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index 171e37ad09..e421cbe615 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -26,7 +26,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"HyperparameterTuningJob",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'HyperparameterTuningJob', + }, ) @@ -96,21 +99,51 @@ class HyperparameterTuningJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) + + study_spec = proto.Field(proto.MESSAGE, number=4, + message=study.StudySpec, + ) + max_trial_count = proto.Field(proto.INT32, number=5) + parallel_trial_count = proto.Field(proto.INT32, number=6) + max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field( - proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, + + trial_job_spec = proto.Field(proto.MESSAGE, number=8, + message=custom_job.CustomJobSpec, + ) + + trials = proto.RepeatedField(proto.MESSAGE, number=9, + message=study.Trial, + ) + + state = proto.Field(proto.ENUM, 
number=10, + enum=job_state.JobState, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) + + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) + + start_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) + + error = proto.Field(proto.MESSAGE, number=15, + message=status.Status, + ) + labels = proto.MapField(proto.STRING, proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index f5fcc170f9..7e47f3e3f7 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -19,13 +19,13 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "GcsSource", - "GcsDestination", - "BigQuerySource", - "BigQueryDestination", - "ContainerRegistryDestination", + 'GcsSource', + 'GcsDestination', + 'BigQuerySource', + 'BigQueryDestination', + 'ContainerRegistryDestination', }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index 98e80c19a2..45c303431a 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -18,46 +18,40 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateCustomJobRequest", - "GetCustomJobRequest", - "ListCustomJobsRequest", - "ListCustomJobsResponse", - "DeleteCustomJobRequest", - "CancelCustomJobRequest", - "CreateDataLabelingJobRequest", - "GetDataLabelingJobRequest", - "ListDataLabelingJobsRequest", - "ListDataLabelingJobsResponse", - "DeleteDataLabelingJobRequest", - "CancelDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "GetHyperparameterTuningJobRequest", - "ListHyperparameterTuningJobsRequest", - "ListHyperparameterTuningJobsResponse", - "DeleteHyperparameterTuningJobRequest", - 
"CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "GetBatchPredictionJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "DeleteBatchPredictionJobRequest", - "CancelBatchPredictionJobRequest", + 'CreateCustomJobRequest', + 'GetCustomJobRequest', + 'ListCustomJobsRequest', + 'ListCustomJobsResponse', + 'DeleteCustomJobRequest', + 'CancelCustomJobRequest', + 'CreateDataLabelingJobRequest', + 'GetDataLabelingJobRequest', + 'ListDataLabelingJobsRequest', + 'ListDataLabelingJobsResponse', + 'DeleteDataLabelingJobRequest', + 'CancelDataLabelingJobRequest', + 'CreateHyperparameterTuningJobRequest', + 'GetHyperparameterTuningJobRequest', + 'ListHyperparameterTuningJobsRequest', + 'ListHyperparameterTuningJobsResponse', + 'DeleteHyperparameterTuningJobRequest', + 'CancelHyperparameterTuningJobRequest', + 'CreateBatchPredictionJobRequest', + 'GetBatchPredictionJobRequest', + 'ListBatchPredictionJobsRequest', + 'ListBatchPredictionJobsResponse', + 'DeleteBatchPredictionJobRequest', + 'CancelBatchPredictionJobRequest', }, ) @@ -76,7 +70,10 @@ class CreateCustomJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) + + custom_job = proto.Field(proto.MESSAGE, number=2, + message=gca_custom_job.CustomJob, + ) class GetCustomJobRequest(proto.Message): @@ -132,10 +129,16 @@ class ListCustomJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListCustomJobsResponse(proto.Message): @@ -155,9 +158,10 @@ class ListCustomJobsResponse(proto.Message): def raw_page(self): return self - custom_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, + custom_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_custom_job.CustomJob, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -201,8 +205,9 @@ class CreateDataLabelingJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field( - proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, + + data_labeling_job = proto.Field(proto.MESSAGE, number=2, + message=gca_data_labeling_job.DataLabelingJob, ) @@ -261,10 +266,17 @@ class ListDataLabelingJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) + order_by = proto.Field(proto.STRING, number=6) @@ -284,9 +296,10 @@ class ListDataLabelingJobsResponse(proto.Message): def raw_page(self): return self - data_labeling_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, + data_labeling_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_data_labeling_job.DataLabelingJob, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -334,9 +347,8 @@ class 
CreateHyperparameterTuningJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field( - proto.MESSAGE, - number=2, + + hyperparameter_tuning_job = proto.Field(proto.MESSAGE, number=2, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -396,10 +408,16 @@ class ListHyperparameterTuningJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListHyperparameterTuningJobsResponse(proto.Message): @@ -421,11 +439,10 @@ class ListHyperparameterTuningJobsResponse(proto.Message): def raw_page(self): return self - hyperparameter_tuning_jobs = proto.RepeatedField( - proto.MESSAGE, - number=1, + hyperparameter_tuning_jobs = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -473,8 +490,9 @@ class CreateBatchPredictionJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field( - proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, + + batch_prediction_job = proto.Field(proto.MESSAGE, number=2, + message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -533,10 +551,16 @@ class ListBatchPredictionJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListBatchPredictionJobsResponse(proto.Message): @@ -557,9 +581,10 @@ class ListBatchPredictionJobsResponse(proto.Message): def raw_page(self): return self - batch_prediction_jobs = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_batch_prediction_job.BatchPredictionJob, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/job_state.py b/google/cloud/aiplatform_v1beta1/types/job_state.py index f86e179b1b..f23f7f60cd 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_state.py +++ b/google/cloud/aiplatform_v1beta1/types/job_state.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"JobState",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'JobState', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 30b81e3efc..88aea166a2 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -18,19 +18,17 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - accelerator_type as gca_accelerator_type, -) +from google.cloud.aiplatform_v1beta1.types import accelerator_type as gca_accelerator_type __protobuf__ = proto.module( - 
package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "MachineSpec", - "DedicatedResources", - "AutomaticResources", - "BatchDedicatedResources", - "ResourcesConsumed", + 'MachineSpec', + 'DedicatedResources', + 'AutomaticResources', + 'BatchDedicatedResources', + 'ResourcesConsumed', }, ) @@ -89,9 +87,11 @@ class MachineSpec(proto.Message): """ machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field( - proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, + + accelerator_type = proto.Field(proto.ENUM, number=2, + enum=gca_accelerator_type.AcceleratorType, ) + accelerator_count = proto.Field(proto.INT32, number=3) @@ -128,8 +128,12 @@ class DedicatedResources(proto.Message): as the default value. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) + machine_spec = proto.Field(proto.MESSAGE, number=1, + message=MachineSpec, + ) + min_replica_count = proto.Field(proto.INT32, number=2) + max_replica_count = proto.Field(proto.INT32, number=3) @@ -166,6 +170,7 @@ class AutomaticResources(proto.Message): """ min_replica_count = proto.Field(proto.INT32, number=1) + max_replica_count = proto.Field(proto.INT32, number=2) @@ -189,8 +194,12 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) + machine_spec = proto.Field(proto.MESSAGE, number=1, + message=MachineSpec, + ) + starting_replica_count = proto.Field(proto.INT32, number=2) + max_replica_count = proto.Field(proto.INT32, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py index 7a467d5069..da5c4d38ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py +++ b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py @@ -19,8 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"ManualBatchTuningParameters",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ManualBatchTuningParameters', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py new file mode 100644 index 0000000000..a96f6d420f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'MigratableResource', + }, +) + + +class MigratableResource(proto.Message): + r"""Represents one resource that exists in automl.googleapis.com, + datalabeling.googleapis.com or ml.googleapis.com. 
+ + Attributes: + ml_engine_model_version (~.migratable_resource.MigratableResource.MlEngineModelVersion): + Output only. Represents one Version in + ml.googleapis.com. + automl_model (~.migratable_resource.MigratableResource.AutomlModel): + Output only. Represents one Model in + automl.googleapis.com. + automl_dataset (~.migratable_resource.MigratableResource.AutomlDataset): + Output only. Represents one Dataset in + automl.googleapis.com. + data_labeling_dataset (~.migratable_resource.MigratableResource.DataLabelingDataset): + Output only. Represents one Dataset in + datalabeling.googleapis.com. + last_migrate_time (~.timestamp.Timestamp): + Output only. Timestamp when the last migrate + attempt on this MigratableResource started. Will + not be set if there's no migrate attempt on this + MigratableResource. + last_update_time (~.timestamp.Timestamp): + Output only. Timestamp when this + MigratableResource was last updated. + """ + class MlEngineModelVersion(proto.Message): + r"""Represents one model Version in ml.googleapis.com. + + Attributes: + endpoint (str): + The ml.googleapis.com endpoint that this model Version + currently lives in. Example values: + + - ml.googleapis.com + - us-central1-ml.googleapis.com + - europe-west4-ml.googleapis.com + - asia-east1-ml.googleapis.com + version (str): + Full resource name of ml engine model Version. Format: + ``projects/{project}/models/{model}/versions/{version}``. + """ + + endpoint = proto.Field(proto.STRING, number=1) + + version = proto.Field(proto.STRING, number=2) + + class AutomlModel(proto.Message): + r"""Represents one Model in automl.googleapis.com. + + Attributes: + model (str): + Full resource name of automl Model. Format: + ``projects/{project}/locations/{location}/models/{model}``. + model_display_name (str): + The Model's display name in + automl.googleapis.com. + """ + + model = proto.Field(proto.STRING, number=1) + + model_display_name = proto.Field(proto.STRING, number=3) + + class AutomlDataset(proto.Message): + r"""Represents one Dataset in automl.googleapis.com. + + Attributes: + dataset (str): + Full resource name of automl Dataset. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}``. + dataset_display_name (str): + The Dataset's display name in + automl.googleapis.com. + """ + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=4) + + class DataLabelingDataset(proto.Message): + r"""Represents one Dataset in datalabeling.googleapis.com. + + Attributes: + dataset (str): + Full resource name of data labeling Dataset. Format: + ``projects/{project}/datasets/{dataset}``. + dataset_display_name (str): + The Dataset's display name in + datalabeling.googleapis.com. + data_labeling_annotated_datasets (Sequence[~.migratable_resource.MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset]): + The migratable AnnotatedDataset in + datalabeling.googleapis.com belongs to the data + labeling Dataset. + """ + class DataLabelingAnnotatedDataset(proto.Message): + r"""Represents one AnnotatedDataset in + datalabeling.googleapis.com. + + Attributes: + annotated_dataset (str): + Full resource name of data labeling AnnotatedDataset. + Format: + + ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. + annotated_dataset_display_name (str): + The AnnotatedDataset's display name in + datalabeling.googleapis.com.
+ """ + + annotated_dataset = proto.Field(proto.STRING, number=1) + + annotated_dataset_display_name = proto.Field(proto.STRING, number=3) + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=4) + + data_labeling_annotated_datasets = proto.RepeatedField(proto.MESSAGE, number=3, + message='MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset', + ) + + ml_engine_model_version = proto.Field(proto.MESSAGE, number=1, oneof='resource', + message=MlEngineModelVersion, + ) + + automl_model = proto.Field(proto.MESSAGE, number=2, oneof='resource', + message=AutomlModel, + ) + + automl_dataset = proto.Field(proto.MESSAGE, number=3, oneof='resource', + message=AutomlDataset, + ) + + data_labeling_dataset = proto.Field(proto.MESSAGE, number=4, oneof='resource', + message=DataLabelingDataset, + ) + + last_migrate_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) + + last_update_time = proto.Field(proto.MESSAGE, number=6, + message=timestamp.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py new file mode 100644 index 0000000000..607629f06a --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import migratable_resource as gca_migratable_resource +from google.cloud.aiplatform_v1beta1.types import operation + + +__protobuf__ = proto.module( + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'SearchMigratableResourcesRequest', + 'SearchMigratableResourcesResponse', + 'BatchMigrateResourcesRequest', + 'MigrateResourceRequest', + 'BatchMigrateResourcesResponse', + 'MigrateResourceResponse', + 'BatchMigrateResourcesOperationMetadata', + }, +) + + +class SearchMigratableResourcesRequest(proto.Message): + r"""Request message for + ``MigrationService.SearchMigratableResources``. + + Attributes: + parent (str): + Required. The location that the migratable resources should + be searched from. It's the AI Platform location that the + resources can be migrated to, not the resources' original + location. Format: + ``projects/{project}/locations/{location}`` + page_size (int): + The standard page size. + The default and maximum value is 100. + page_token (str): + The standard page token. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + +class SearchMigratableResourcesResponse(proto.Message): + r"""Response message for + ``MigrationService.SearchMigratableResources``. 
+ + Attributes: + migratable_resources (Sequence[~.gca_migratable_resource.MigratableResource]): + All migratable resources that can be migrated + to the location specified in the request. + next_page_token (str): + The standard next-page token. The ``migratable_resources`` + list may not fill the ``page_size`` from + SearchMigratableResourcesRequest even + when there are subsequent pages. + """ + + @property + def raw_page(self): + return self + + migratable_resources = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_migratable_resource.MigratableResource, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class BatchMigrateResourcesRequest(proto.Message): + r"""Request message for + ``MigrationService.BatchMigrateResources``. + + Attributes: + parent (str): + Required. The location that the migrated resources will + live in. Format: ``projects/{project}/locations/{location}`` + migrate_resource_requests (Sequence[~.migration_service.MigrateResourceRequest]): + Required. The request messages specifying the + resources to migrate. They must be in the same + location as the destination. Up to 50 resources + can be migrated in one batch. + """ + + parent = proto.Field(proto.STRING, number=1) + + migrate_resource_requests = proto.RepeatedField(proto.MESSAGE, number=2, + message='MigrateResourceRequest', + ) + + +class MigrateResourceRequest(proto.Message): + r"""Config of migrating one resource from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. + + Attributes: + migrate_ml_engine_model_version_config (~.migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig): + Config for migrating Version in + ml.googleapis.com to AI Platform's Model. + migrate_automl_model_config (~.migration_service.MigrateResourceRequest.MigrateAutomlModelConfig): + Config for migrating Model in + automl.googleapis.com to AI Platform's Model. + migrate_automl_dataset_config (~.migration_service.MigrateResourceRequest.MigrateAutomlDatasetConfig): + Config for migrating Dataset in + automl.googleapis.com to AI Platform's Dataset. + migrate_data_labeling_dataset_config (~.migration_service.MigrateResourceRequest.MigrateDataLabelingDatasetConfig): + Config for migrating Dataset in + datalabeling.googleapis.com to AI Platform's + Dataset. + """ + class MigrateMlEngineModelVersionConfig(proto.Message): + r"""Config for migrating version in ml.googleapis.com to AI + Platform's Model. + + Attributes: + endpoint (str): + Required. The ml.googleapis.com endpoint that this model + version should be migrated from. Example values: + + - ml.googleapis.com + + - us-central1-ml.googleapis.com + + - europe-west4-ml.googleapis.com + + - asia-east1-ml.googleapis.com + model_version (str): + Required. Full resource name of ml engine model version. + Format: + ``projects/{project}/models/{model}/versions/{version}``. + model_display_name (str): + Required. Display name of the model in AI + Platform. System will pick a display name if + unspecified. + """ + + endpoint = proto.Field(proto.STRING, number=1) + + model_version = proto.Field(proto.STRING, number=2) + + model_display_name = proto.Field(proto.STRING, number=3) + + class MigrateAutomlModelConfig(proto.Message): + r"""Config for migrating Model in automl.googleapis.com to AI + Platform's Model. + + Attributes: + model (str): + Required. Full resource name of automl Model. Format: + ``projects/{project}/locations/{location}/models/{model}``. + model_display_name (str): + Optional.
Display name of the model in AI + Platform. System will pick a display name if + unspecified. + """ + + model = proto.Field(proto.STRING, number=1) + + model_display_name = proto.Field(proto.STRING, number=2) + + class MigrateAutomlDatasetConfig(proto.Message): + r"""Config for migrating Dataset in automl.googleapis.com to AI + Platform's Dataset. + + Attributes: + dataset (str): + Required. Full resource name of automl Dataset. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}``. + dataset_display_name (str): + Required. Display name of the Dataset in AI + Platform. System will pick a display name if + unspecified. + """ + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=2) + + class MigrateDataLabelingDatasetConfig(proto.Message): + r"""Config for migrating Dataset in datalabeling.googleapis.com + to AI Platform's Dataset. + + Attributes: + dataset (str): + Required. Full resource name of data labeling Dataset. + Format: ``projects/{project}/datasets/{dataset}``. + dataset_display_name (str): + Optional. Display name of the Dataset in AI + Platform. System will pick a display name if + unspecified. + migrate_data_labeling_annotated_dataset_configs (Sequence[~.migration_service.MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig]): + Optional. Configs for migrating + AnnotatedDataset in datalabeling.googleapis.com + to AI Platform's SavedQuery. The specified + AnnotatedDatasets have to belong to the + datalabeling Dataset. + """ + class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): + r"""Config for migrating AnnotatedDataset in + datalabeling.googleapis.com to AI Platform's SavedQuery. + + Attributes: + annotated_dataset (str): + Required. Full resource name of data labeling + AnnotatedDataset. Format: + + ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. + """ + + annotated_dataset = proto.Field(proto.STRING, number=1) + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=2) + + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField(proto.MESSAGE, number=3, + message='MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig', + ) + + migrate_ml_engine_model_version_config = proto.Field(proto.MESSAGE, number=1, oneof='request', + message=MigrateMlEngineModelVersionConfig, + ) + + migrate_automl_model_config = proto.Field(proto.MESSAGE, number=2, oneof='request', + message=MigrateAutomlModelConfig, + ) + + migrate_automl_dataset_config = proto.Field(proto.MESSAGE, number=3, oneof='request', + message=MigrateAutomlDatasetConfig, + ) + + migrate_data_labeling_dataset_config = proto.Field(proto.MESSAGE, number=4, oneof='request', + message=MigrateDataLabelingDatasetConfig, + ) + + +class BatchMigrateResourcesResponse(proto.Message): + r"""Response message for + ``MigrationService.BatchMigrateResources``. + + Attributes: + migrate_resource_responses (Sequence[~.migration_service.MigrateResourceResponse]): + Successfully migrated resources. + """ + + migrate_resource_responses = proto.RepeatedField(proto.MESSAGE, number=1, + message='MigrateResourceResponse', + ) + + +class MigrateResourceResponse(proto.Message): + r"""Describes a successfully migrated resource. + + Attributes: + dataset (str): + Migrated Dataset's resource name. + model (str): + Migrated Model's resource name. 
+ migratable_resource (~.gca_migratable_resource.MigratableResource): + Before migration, the identifier in + ml.googleapis.com, automl.googleapis.com or + datalabeling.googleapis.com. + """ + + dataset = proto.Field(proto.STRING, number=1, oneof='migrated_resource') + + model = proto.Field(proto.STRING, number=2, oneof='migrated_resource') + + migratable_resource = proto.Field(proto.MESSAGE, number=3, + message=gca_migratable_resource.MigratableResource, + ) + + +class BatchMigrateResourcesOperationMetadata(proto.Message): + r"""Runtime operation information for + ``MigrationService.BatchMigrateResources``. + + Attributes: + generic_metadata (~.operation.GenericOperationMetadata): + The common part of the operation metadata. + """ + + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 39cb44206e..abd5f67e94 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -26,8 +26,13 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Model', + 'PredictSchemata', + 'ModelContainerSpec', + 'Port', + }, ) @@ -227,7 +232,6 @@ class Model(proto.Message): See https://goo.gl/xmQnxf for more information and examples of labels. """ - class DeploymentResourcesType(proto.Enum): r"""Identifies a type of Model's prediction resources.""" DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 @@ -264,7 +268,6 @@ class ExportFormat(proto.Message): Output only. The content of this Model that may be exported. 
""" - class ExportableContent(proto.Enum): r"""The Model content that can be exported.""" EXPORTABLE_CONTENT_UNSPECIFIED = 0 @@ -272,36 +275,65 @@ class ExportableContent(proto.Enum): IMAGE = 2 id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField( - proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", + + exportable_contents = proto.RepeatedField(proto.ENUM, number=2, + enum='Model.ExportFormat.ExportableContent', ) name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) + + predict_schemata = proto.Field(proto.MESSAGE, number=4, + message='PredictSchemata', + ) + metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - supported_export_formats = proto.RepeatedField( - proto.MESSAGE, number=20, message=ExportFormat, + + metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) + + supported_export_formats = proto.RepeatedField(proto.MESSAGE, number=20, + message=ExportFormat, ) + training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) + + container_spec = proto.Field(proto.MESSAGE, number=9, + message='ModelContainerSpec', + ) + artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField( - proto.ENUM, number=10, enum=DeploymentResourcesType, + + supported_deployment_resources_types = proto.RepeatedField(proto.ENUM, number=10, + enum=DeploymentResourcesType, ) + supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) + supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - deployed_models = proto.RepeatedField( - proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, + + create_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, ) - explanation_spec = proto.Field( - proto.MESSAGE, number=23, message=explanation.ExplanationSpec, + + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) + + deployed_models = proto.RepeatedField(proto.MESSAGE, number=15, + message=deployed_model_ref.DeployedModelRef, ) + + explanation_spec = proto.Field(proto.MESSAGE, number=23, + message=explanation.ExplanationSpec, + ) + etag = proto.Field(proto.STRING, number=16) + labels = proto.MapField(proto.STRING, proto.STRING, number=17) @@ -363,75 +395,278 @@ class PredictSchemata(proto.Message): """ instance_schema_uri = proto.Field(proto.STRING, number=1) + parameters_schema_uri = proto.Field(proto.STRING, number=2) + prediction_schema_uri = proto.Field(proto.STRING, number=3) class ModelContainerSpec(proto.Message): - r"""Specification of the container to be deployed for this Model. The - ModelContainerSpec is based on the Kubernetes Container - `specification `__. + r"""Specification of a container for serving predictions. This message + is a subset of the [Kubernetes Container v1 core + + specification](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). Attributes: image_uri (str): - Required. Immutable. 
The URI of the Model serving container - file in the Container Registry. The container image is - ingested upon + Required. Immutable. URI of the Docker image to be used as + the custom container for serving predictions. This URI must + identify an image in Artifact Registry or Container + Registry. Learn more about the [container publishing + + requirements](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#publishing), + including permissions requirements for the AI Platform + Service Agent. + + The container image is ingested upon ``ModelService.UploadModel``, stored internally, and this original path is afterwards not used. + + To learn about the requirements for the Docker image itself, + read [Custom container + + requirements](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements). command (Sequence[str]): - Immutable. The command with which the container is run. Not - executed within a shell. The Docker image's ENTRYPOINT is - used if this is not provided. Variable references - $(VAR_NAME) are expanded using the container's environment. - If a variable cannot be resolved, the reference in the input - string will be unchanged. The $(VAR_NAME) syntax can be - escaped with a double $$, ie: $$(VAR_NAME). Escaped - references will never be expanded, regardless of whether the - variable exists or not. More info: - https://tinyurl.com/y42hmlxe + Immutable. Specifies the command that runs when the + container starts. This overrides the container's + + [``ENTRYPOINT``](https://docs.docker.com/engine/reference/builder/#entrypoint). + Specify this field as an array of executable and arguments, + similar to a Docker ``ENTRYPOINT``'s "exec" form, not its + "shell" form. + + If you do not specify this field, then the container's + ``ENTRYPOINT`` runs, in conjunction with the + ``args`` + field or the container's + [``CMD``](https://docs.docker.com/engine/reference/builder/#cmd), + if either exists. If this field is not specified and the + container does not have an ``ENTRYPOINT``, then refer to the + [Docker documentation about how ``CMD`` and ``ENTRYPOINT`` + + interact](https://docs.docker.com/engine/reference/builder/#understand-how-cmd-and-entrypoint-interact). + + If you specify this field, then you can also specify the + ``args`` field to provide additional arguments for this + command. However, if you specify this field, then the + container's ``CMD`` is ignored. See the [Kubernetes + documentation about how the ``command`` and ``args`` fields + interact with a container's ``ENTRYPOINT`` and + + ``CMD``](https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#notes). + + In this field, you can reference [environment variables set + by AI + + Platform](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables) + and environment variables set in the + ``env`` + field. You cannot reference environment variables set in the + Docker image. In order for environment variables to be + expanded, reference them by using the following syntax: + $(VARIABLE_NAME) Note that this differs from Bash variable + expansion, which does not use parentheses. If a variable + cannot be resolved, the reference in the input string is + used unchanged.
To avoid variable expansion, you can escape + this syntax with ``$$``; for example: $$(VARIABLE_NAME) This + field corresponds to the ``command`` field of the + [Kubernetes Containers v1 core + + API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). args (Sequence[str]): - Immutable. The arguments to the command. The Docker image's - CMD is used if this is not provided. Variable references - $(VAR_NAME) are expanded using the container's environment. - If a variable cannot be resolved, the reference in the input - string will be unchanged. The $(VAR_NAME) syntax can be - escaped with a double $$, ie: $$(VAR_NAME). Escaped - references will never be expanded, regardless of whether the - variable exists or not. More info: - https://tinyurl.com/y42hmlxe + Immutable. Specifies arguments for the command that runs + when the container starts. This overrides the container's + [``CMD``](https://docs.docker.com/engine/reference/builder/#cmd). + Specify this field as an array of executable and arguments, + similar to a Docker ``CMD``'s "default parameters" form. + + If you don't specify this field but do specify the + ``command`` + field, then the command from the ``command`` field runs + without any additional arguments. See the [Kubernetes + documentation about how the ``command`` and ``args`` fields + interact with a container's ``ENTRYPOINT`` and + + ``CMD``](https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#notes). + + If you don't specify this field and don't specify the + ``command`` field, then the container's + [``ENTRYPOINT``](https://docs.docker.com/engine/reference/builder/#entrypoint) + and ``CMD`` determine what runs based on their default + behavior. See the [Docker documentation about how ``CMD`` + and ``ENTRYPOINT`` + + interact](https://docs.docker.com/engine/reference/builder/#understand-how-cmd-and-entrypoint-interact). + + In this field, you can reference [environment variables set + by AI + + Platform](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables) + and environment variables set in the + ``env`` + field. You cannot reference environment variables set in the + Docker image. In order for environment variables to be + expanded, reference them by using the following syntax: + $(VARIABLE_NAME) Note that this differs from Bash variable + expansion, which does not use parentheses. If a variable + cannot be resolved, the reference in the input string is + used unchanged. To avoid variable expansion, you can escape + this syntax with ``$$``; for example: $$(VARIABLE_NAME) This + field corresponds to the ``args`` field of the [Kubernetes + Containers v1 core + + API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). env (Sequence[~.env_var.EnvVar]): - Immutable. The environment variables that are - to be present in the container. + Immutable. List of environment variables to set in the + container. After the container starts running, code running + in the container can read these environment variables. + + Additionally, the + ``command`` + and + ``args`` + fields can reference these variables. Later entries in this + list can also reference earlier entries. For example, the + following example sets the variable ``VAR_2`` to have the + value ``foo bar``: + + .. code:: json + + [ + { + "name": "VAR_1", + "value": "foo" + }, + { + "name": "VAR_2", + "value": "$(VAR_1) bar" + } + ] + + If you switch the order of the variables in the example, + then the expansion does not occur.
+ + This field corresponds to the ``env`` field of the + [Kubernetes Containers v1 core + + API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). ports (Sequence[~.model.Port]): - Immutable. Declaration of ports that are - exposed by the container. This field is - primarily informational, it gives AI Platform - information about the network connections the - container uses. Listing or not a port here has - no impact on whether the port is actually - exposed, any port listening on the default - "0.0.0.0" address inside a container will be - accessible from the network. + Immutable. List of ports to expose from the container. AI + Platform sends any prediction requests that it receives to + the first port on this list. AI Platform also sends + [liveness and health + + checks](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#health) to + this port. + + If you do not specify this field, it defaults to the following + value: + + .. code:: json + + [ + { + "containerPort": 8080 + } + ] + + AI Platform does not use ports other than the first one + listed. This field corresponds to the ``ports`` field of the + [Kubernetes Containers v1 core + + API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). predict_route (str): - Immutable. An HTTP path to send prediction - requests to the container, and which must be - supported by it. If not specified a default HTTP - path will be used by AI Platform. + Immutable. HTTP path on the container to send prediction + requests to. AI Platform forwards requests sent using + ``projects.locations.endpoints.predict`` + to this path on the container's IP address and port. AI + Platform then returns the container's response in the API + response. + + For example, if you set this field to ``/foo``, then when AI + Platform receives a prediction request, it forwards the + request body in a POST request to the following URL on the + container: localhost:PORT/foo PORT refers to the first value + of this ``ModelContainerSpec``'s + ``ports`` + field. + + If you don't specify this field, it defaults to the + following value when you [deploy this Model to an + Endpoint][google.cloud.aiplatform.v1beta1.EndpointService.DeployModel]: + /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict + The placeholders in this value are replaced as follows: + + - ENDPOINT: The last segment (following ``endpoints/``) of + the ``Endpoint.name`` field of the Endpoint where this + Model has been deployed. (AI Platform makes this value + available to your container code as the + [``AIP_ENDPOINT_ID`` environment + + variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + + - DEPLOYED_MODEL: + ``DeployedModel.id`` + of the ``DeployedModel``. (AI Platform makes this value + available to your container code as the + [``AIP_DEPLOYED_MODEL_ID`` environment + + variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) health_route (str): - Immutable. An HTTP path to send health check - requests to the container, and which must be - supported by it. If not specified a standard - HTTP path will be used by AI Platform. + Immutable. HTTP path on the container to send health checks + to. AI Platform intermittently sends GET requests to this + path on the container's IP address and port to check that + the container is healthy.
Read more about [health + + checks](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#checks). + + For example, if you set this field to ``/bar``, then AI + Platform intermittently sends a GET request to the following + URL on the container: localhost:PORT/bar PORT refers to the + first value of this ``ModelContainerSpec``'s + ``ports`` + field. + + If you don't specify this field, it defaults to the + following value when you [deploy this Model to an + Endpoint][google.cloud.aiplatform.v1beta1.EndpointService.DeployModel]: + /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict + The placeholders in this value are replaced as follows: + + - ENDPOINT: The last segment (following ``endpoints/``) of + the ``Endpoint.name`` field of the Endpoint where this + Model has been deployed. (AI Platform makes this value + available to your container code as the + [``AIP_ENDPOINT_ID`` environment + + variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + + - DEPLOYED_MODEL: + ``DeployedModel.id`` + of the ``DeployedModel``. (AI Platform makes this value + available to your container code as the + [``AIP_DEPLOYED_MODEL_ID`` environment + + variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + """ image_uri = proto.Field(proto.STRING, number=1) + command = proto.RepeatedField(proto.STRING, number=2) + args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) - ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) + + env = proto.RepeatedField(proto.MESSAGE, number=4, + message=env_var.EnvVar, + ) + + ports = proto.RepeatedField(proto.MESSAGE, number=5, + message='Port', + ) + predict_route = proto.Field(proto.STRING, number=6) + health_route = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index 13f49e963e..5613b3017d 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -24,7 +24,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluation",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ModelEvaluation', + }, ) @@ -68,12 +71,21 @@ class ModelEvaluation(proto.Message): """ name = proto.Field(proto.STRING, number=1) + metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + metrics = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) + + create_time = proto.Field(proto.MESSAGE, number=4, + message=timestamp.Timestamp, + ) + slice_dimensions = proto.RepeatedField(proto.STRING, number=5) - model_explanation = proto.Field( - proto.MESSAGE, number=8, message=explanation.ModelExplanation, + + model_explanation = proto.Field(proto.MESSAGE, number=8, + message=explanation.ModelExplanation, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index fe8dc19754..7d21157f1a 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -23,7 +23,10
@@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluationSlice",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'ModelEvaluationSlice', + }, ) @@ -36,7 +39,7 @@ class ModelEvaluationSlice(proto.Message): name (str): Output only. The resource name of the ModelEvaluationSlice. - slice (~.model_evaluation_slice.ModelEvaluationSlice.Slice): + slice_ (~.model_evaluation_slice.ModelEvaluationSlice.Slice): Output only. The slice of the test data that is used to evaluate the Model. metrics_schema_uri (str): @@ -54,7 +57,6 @@ class ModelEvaluationSlice(proto.Message): Output only. Timestamp when this ModelEvaluationSlice was created. """ - class Slice(proto.Message): r"""Definition of a slice. @@ -74,13 +76,24 @@ class Slice(proto.Message): """ dimension = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.STRING, number=2) name = proto.Field(proto.STRING, number=1) - slice = proto.Field(proto.MESSAGE, number=2, message=Slice,) + + slice_ = proto.Field(proto.MESSAGE, number=2, + message=Slice, + ) + metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + metrics = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) + + create_time = proto.Field(proto.MESSAGE, number=5, + message=timestamp.Timestamp, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 42706490c1..e5945e49c0 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -27,25 +27,25 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "UploadModelRequest", - "UploadModelOperationMetadata", - "UploadModelResponse", - "GetModelRequest", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", - "DeleteModelRequest", - "ExportModelRequest", - "ExportModelOperationMetadata", - "ExportModelResponse", - "GetModelEvaluationRequest", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", - "GetModelEvaluationSliceRequest", - "ListModelEvaluationSlicesRequest", - "ListModelEvaluationSlicesResponse", + 'UploadModelRequest', + 'UploadModelOperationMetadata', + 'UploadModelResponse', + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'UpdateModelRequest', + 'DeleteModelRequest', + 'ExportModelRequest', + 'ExportModelOperationMetadata', + 'ExportModelResponse', + 'GetModelEvaluationRequest', + 'ListModelEvaluationsRequest', + 'ListModelEvaluationsResponse', + 'GetModelEvaluationSliceRequest', + 'ListModelEvaluationSlicesRequest', + 'ListModelEvaluationSlicesResponse', }, ) @@ -64,7 +64,10 @@ class UploadModelRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) + + model = proto.Field(proto.MESSAGE, number=2, + message=gca_model.Model, + ) class UploadModelOperationMetadata(proto.Message): @@ -77,8 +80,8 @@ class UploadModelOperationMetadata(proto.Message): The common part of the operation metadata. 
""" - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -133,10 +136,16 @@ class ListModelsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelsResponse(proto.Message): @@ -156,7 +165,10 @@ class ListModelsResponse(proto.Message): def raw_page(self): return self - models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) + models = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_model.Model, + ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -175,8 +187,13 @@ class UpdateModelRequest(proto.Message): [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). """ - model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + model = proto.Field(proto.MESSAGE, number=1, + message=gca_model.Model, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, + ) class DeleteModelRequest(proto.Message): @@ -205,7 +222,6 @@ class ExportModelRequest(proto.Message): Required. The desired output location and configuration. """ - class OutputConfig(proto.Message): r"""Output configuration for the Model export. @@ -236,15 +252,20 @@ class OutputConfig(proto.Message): """ export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field( - proto.MESSAGE, number=3, message=io.GcsDestination, + + artifact_destination = proto.Field(proto.MESSAGE, number=3, + message=io.GcsDestination, ) - image_destination = proto.Field( - proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, + + image_destination = proto.Field(proto.MESSAGE, number=4, + message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) + + output_config = proto.Field(proto.MESSAGE, number=2, + message=OutputConfig, + ) class ExportModelOperationMetadata(proto.Message): @@ -259,7 +280,6 @@ class ExportModelOperationMetadata(proto.Message): Output only. Information further describing the output of this Model export. """ - class OutputInfo(proto.Message): r"""Further describes the output of the ExportModel. Supplements ``ExportModelRequest.OutputConfig``. 
@@ -278,12 +298,16 @@ class OutputInfo(proto.Message): """ artifact_output_uri = proto.Field(proto.STRING, number=2) + image_output_uri = proto.Field(proto.STRING, number=3) - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, + ) + + output_info = proto.Field(proto.MESSAGE, number=2, + message=OutputInfo, ) - output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) class ExportModelResponse(proto.Message): @@ -331,10 +355,16 @@ class ListModelEvaluationsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelEvaluationsResponse(proto.Message): @@ -355,9 +385,10 @@ class ListModelEvaluationsResponse(proto.Message): def raw_page(self): return self - model_evaluations = proto.RepeatedField( - proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, + model_evaluations = proto.RepeatedField(proto.MESSAGE, number=1, + message=model_evaluation.ModelEvaluation, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -403,10 +434,16 @@ class ListModelEvaluationSlicesRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListModelEvaluationSlicesResponse(proto.Message): @@ -427,9 +464,10 @@ class ListModelEvaluationSlicesResponse(proto.Message): def raw_page(self): return self - model_evaluation_slices = proto.RepeatedField( - proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, + model_evaluation_slices = proto.RepeatedField(proto.MESSAGE, number=1, + message=model_evaluation_slice.ModelEvaluationSlice, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 3451fb9c8c..bf2a5906bd 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -23,8 +23,11 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'GenericOperationMetadata', + 'DeleteOperationMetadata', + }, ) @@ -46,11 +49,17 @@ class GenericOperationMetadata(proto.Message): updated for the last time. 
""" - partial_failures = proto.RepeatedField( - proto.MESSAGE, number=1, message=status.Status, + partial_failures = proto.RepeatedField(proto.MESSAGE, number=1, + message=status.Status, + ) + + create_time = proto.Field(proto.MESSAGE, number=2, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=3, + message=timestamp.Timestamp, ) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) class DeleteOperationMetadata(proto.Message): @@ -61,8 +70,8 @@ class DeleteOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index 7ba3638c51..727855e58a 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -18,21 +18,19 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", + 'CreateTrainingPipelineRequest', + 'GetTrainingPipelineRequest', + 'ListTrainingPipelinesRequest', + 'ListTrainingPipelinesResponse', + 'DeleteTrainingPipelineRequest', + 'CancelTrainingPipelineRequest', }, ) @@ -51,8 +49,9 @@ class CreateTrainingPipelineRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field( - proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, + + training_pipeline = proto.Field(proto.MESSAGE, number=2, + message=gca_training_pipeline.TrainingPipeline, ) @@ -108,10 +107,16 @@ class ListTrainingPipelinesRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=5, + message=field_mask.FieldMask, + ) class ListTrainingPipelinesResponse(proto.Message): @@ -132,9 +137,10 @@ class ListTrainingPipelinesResponse(proto.Message): def raw_page(self): return self - training_pipelines = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, + training_pipelines = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_training_pipeline.TrainingPipeline, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py index cede653bd6..b04954f602 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py +++ 
b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"PipelineState",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'PipelineState', + }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index e937a6d8e4..872990a5f1 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -23,12 +23,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "PredictRequest", - "PredictResponse", - "ExplainRequest", - "ExplainResponse", + 'PredictRequest', + 'PredictResponse', + 'ExplainRequest', + 'ExplainResponse', }, ) @@ -64,8 +64,14 @@ class PredictRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + + instances = proto.RepeatedField(proto.MESSAGE, number=2, + message=struct.Value, + ) + + parameters = proto.Field(proto.MESSAGE, number=3, + message=struct.Value, + ) class PredictResponse(proto.Message): @@ -85,7 +91,10 @@ class PredictResponse(proto.Message): served this prediction. """ - predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) + predictions = proto.RepeatedField(proto.MESSAGE, number=1, + message=struct.Value, + ) + deployed_model_id = proto.Field(proto.STRING, number=2) @@ -124,8 +133,15 @@ class ExplainRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + + instances = proto.RepeatedField(proto.MESSAGE, number=2, + message=struct.Value, + ) + + parameters = proto.Field(proto.MESSAGE, number=4, + message=struct.Value, + ) + deployed_model_id = proto.Field(proto.STRING, number=3) @@ -135,8 +151,8 @@ class ExplainResponse(proto.Message): Attributes: explanations (Sequence[~.explanation.Explanation]): - The explanations of the [Model's - predictions][PredictionResponse.predictions][]. + The explanations of the Model's + ``PredictResponse.predictions``. It has the same number of elements as ``instances`` @@ -146,9 +162,10 @@ class ExplainResponse(proto.Message): served this explanation. 
""" - explanations = proto.RepeatedField( - proto.MESSAGE, number=1, message=explanation.Explanation, + explanations = proto.RepeatedField(proto.MESSAGE, number=1, + message=explanation.Explanation, ) + deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py index 21ab5f9c47..f75416157b 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"SpecialistPool",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'SpecialistPool', + }, ) @@ -55,9 +58,13 @@ class SpecialistPool(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + specialist_managers_count = proto.Field(proto.INT32, number=3) + specialist_manager_emails = proto.RepeatedField(proto.STRING, number=4) + pending_data_labeling_jobs = proto.RepeatedField(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 02f0dac96f..8ee901a444 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -24,16 +24,16 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "CreateSpecialistPoolRequest", - "CreateSpecialistPoolOperationMetadata", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", - "UpdateSpecialistPoolOperationMetadata", + 'CreateSpecialistPoolRequest', + 'CreateSpecialistPoolOperationMetadata', + 'GetSpecialistPoolRequest', + 'ListSpecialistPoolsRequest', + 'ListSpecialistPoolsResponse', + 'DeleteSpecialistPoolRequest', + 'UpdateSpecialistPoolRequest', + 'UpdateSpecialistPoolOperationMetadata', }, ) @@ -52,8 +52,9 @@ class CreateSpecialistPoolRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field( - proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, + + specialist_pool = proto.Field(proto.MESSAGE, number=2, + message=gca_specialist_pool.SpecialistPool, ) @@ -66,8 +67,8 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): The operation generic information. 
""" - generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + generic_metadata = proto.Field(proto.MESSAGE, number=1, + message=operation.GenericOperationMetadata, ) @@ -108,9 +109,14 @@ class ListSpecialistPoolsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + page_size = proto.Field(proto.INT32, number=2) + page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) + + read_mask = proto.Field(proto.MESSAGE, number=4, + message=field_mask.FieldMask, + ) class ListSpecialistPoolsResponse(proto.Message): @@ -129,9 +135,10 @@ class ListSpecialistPoolsResponse(proto.Message): def raw_page(self): return self - specialist_pools = proto.RepeatedField( - proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + specialist_pools = proto.RepeatedField(proto.MESSAGE, number=1, + message=gca_specialist_pool.SpecialistPool, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -152,6 +159,7 @@ class DeleteSpecialistPoolRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) + force = proto.Field(proto.BOOL, number=2) @@ -168,10 +176,13 @@ class UpdateSpecialistPoolRequest(proto.Message): resource. """ - specialist_pool = proto.Field( - proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field(proto.MESSAGE, number=1, + message=gca_specialist_pool.SpecialistPool, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, + message=field_mask.FieldMask, ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateSpecialistPoolOperationMetadata(proto.Message): @@ -189,8 +200,9 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): """ specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field( - proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, + + generic_metadata = proto.Field(proto.MESSAGE, number=2, + message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 55f20127ef..b60e344617 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -23,8 +23,12 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", - manifest={"Trial", "StudySpec", "Measurement",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'Trial', + 'StudySpec', + 'Measurement', + }, ) @@ -54,7 +58,6 @@ class Trial(proto.Message): Trial. It's set for a HyperparameterTuningJob's Trial. 
""" - class State(proto.Enum): r"""Describes a Trial state.""" STATE_UNSPECIFIED = 0 @@ -81,14 +84,33 @@ class Parameter(proto.Message): """ parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) + + value = proto.Field(proto.MESSAGE, number=2, + message=struct.Value, + ) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, enum=State,) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) - final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + + state = proto.Field(proto.ENUM, number=3, + enum=State, + ) + + parameters = proto.RepeatedField(proto.MESSAGE, number=4, + message=Parameter, + ) + + final_measurement = proto.Field(proto.MESSAGE, number=5, + message='Measurement', + ) + + start_time = proto.Field(proto.MESSAGE, number=7, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=8, + message=timestamp.Timestamp, + ) + custom_job = proto.Field(proto.STRING, number=11) @@ -103,7 +125,6 @@ class StudySpec(proto.Message): algorithm (~.study.StudySpec.Algorithm): The search algorithm specified for the Study. """ - class Algorithm(proto.Enum): r"""The available search algorithms for the Study.""" ALGORITHM_UNSPECIFIED = 0 @@ -122,7 +143,6 @@ class MetricSpec(proto.Message): Required. The optimization goal of the metric. """ - class GoalType(proto.Enum): r"""The available types of optimization goals.""" GOAL_TYPE_UNSPECIFIED = 0 @@ -130,7 +150,10 @@ class GoalType(proto.Enum): MINIMIZE = 2 metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) + + goal = proto.Field(proto.ENUM, number=2, + enum='StudySpec.MetricSpec.GoalType', + ) class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. @@ -152,7 +175,6 @@ class ParameterSpec(proto.Message): How the parameter should be scaled. Leave unset for ``CATEGORICAL`` parameters. 
""" - class ScaleType(proto.Enum): r"""The type of scaling that should be applied to this parameter.""" SCALE_TYPE_UNSPECIFIED = 0 @@ -173,6 +195,7 @@ class DoubleValueSpec(proto.Message): """ min_value = proto.Field(proto.DOUBLE, number=1) + max_value = proto.Field(proto.DOUBLE, number=2) class IntegerValueSpec(proto.Message): @@ -188,6 +211,7 @@ class IntegerValueSpec(proto.Message): """ min_value = proto.Field(proto.INT64, number=1) + max_value = proto.Field(proto.INT64, number=2) class CategoricalValueSpec(proto.Message): @@ -215,30 +239,39 @@ class DiscreteValueSpec(proto.Message): values = proto.RepeatedField(proto.DOUBLE, number=1) - double_value_spec = proto.Field( - proto.MESSAGE, number=2, message="StudySpec.ParameterSpec.DoubleValueSpec", + double_value_spec = proto.Field(proto.MESSAGE, number=2, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.DoubleValueSpec', ) - integer_value_spec = proto.Field( - proto.MESSAGE, number=3, message="StudySpec.ParameterSpec.IntegerValueSpec", + + integer_value_spec = proto.Field(proto.MESSAGE, number=3, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.IntegerValueSpec', ) - categorical_value_spec = proto.Field( - proto.MESSAGE, - number=4, - message="StudySpec.ParameterSpec.CategoricalValueSpec", + + categorical_value_spec = proto.Field(proto.MESSAGE, number=4, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.CategoricalValueSpec', ) - discrete_value_spec = proto.Field( - proto.MESSAGE, - number=5, - message="StudySpec.ParameterSpec.DiscreteValueSpec", + + discrete_value_spec = proto.Field(proto.MESSAGE, number=5, oneof='parameter_value_spec', + message='StudySpec.ParameterSpec.DiscreteValueSpec', ) + parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field( - proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", + + scale_type = proto.Field(proto.ENUM, number=6, + enum='StudySpec.ParameterSpec.ScaleType', ) - metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) - algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, + message=MetricSpec, + ) + + parameters = proto.RepeatedField(proto.MESSAGE, number=2, + message=ParameterSpec, + ) + + algorithm = proto.Field(proto.ENUM, number=3, + enum=Algorithm, + ) class Measurement(proto.Message): @@ -256,7 +289,6 @@ class Measurement(proto.Message): evaluating the objective functions using suggested Parameter values. """ - class Metric(proto.Message): r"""A message representing a metric in the measurement. 
@@ -270,10 +302,14 @@ class Metric(proto.Message): """ metric_id = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.DOUBLE, number=2) step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) + + metrics = proto.RepeatedField(proto.MESSAGE, number=3, + message=Metric, + ) __all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index bb32b7b787..0729605971 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -27,14 +27,14 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", + package='google.cloud.aiplatform.v1beta1', manifest={ - "TrainingPipeline", - "InputDataConfig", - "FractionSplit", - "FilterSplit", - "PredefinedSplit", - "TimestampSplit", + 'TrainingPipeline', + 'InputDataConfig', + 'FractionSplit', + 'FilterSplit', + 'PredefinedSplit', + 'TimestampSplit', }, )
@@ -143,18 +143,51 @@ class TrainingPipeline(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) + + input_data_config = proto.Field(proto.MESSAGE, number=3, + message='InputDataConfig', + ) + training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) - state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + + training_task_inputs = proto.Field(proto.MESSAGE, number=5, + message=struct.Value, + ) + + training_task_metadata = proto.Field(proto.MESSAGE, number=6, + message=struct.Value, + ) + + model_to_upload = proto.Field(proto.MESSAGE, number=7, + message=model.Model, + ) + + state = proto.Field(proto.ENUM, number=9, + enum=pipeline_state.PipelineState, + ) + + error = proto.Field(proto.MESSAGE, number=10, + message=status.Status, + ) + + create_time = proto.Field(proto.MESSAGE, number=11, + message=timestamp.Timestamp, + ) + + start_time = proto.Field(proto.MESSAGE, number=12, + message=timestamp.Timestamp, + ) + + end_time = proto.Field(proto.MESSAGE, number=13, + message=timestamp.Timestamp, + ) + + update_time = proto.Field(proto.MESSAGE, number=14, + message=timestamp.Timestamp, + ) + labels = proto.MapField(proto.STRING, proto.STRING, number=15)
@@ -177,17 +210,31 @@ class InputDataConfig(proto.Message): Split based on the timestamp of the input data pieces. gcs_destination (~.io.GcsDestination): - The Google Cloud Storage location. + The Google Cloud Storage location where the output is to be + written to. In the given directory a new directory will be + created with name: + ``dataset-<dataset-id>-<annotation-type>-<timestamp>`` + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format. All training input data will be written into that + directory. The AI Platform environment variables representing Google Cloud Storage data URIs will always be represented in the Google Cloud Storage wildcard format to support sharded - data. e.g.: "gs://.../training-\* + data. e.g.: "gs://.../training-*.jsonl" - AIP_DATA_FORMAT = "jsonl". - - AIP_TRAINING_DATA_URI = "gcs_destination/training-*" - - AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*" - - AIP_TEST_DATA_URI = "gcs_destination/test-*". + - AIP_TRAINING_DATA_URI = + + "gcs_destination/dataset-<dataset-id>-<annotation-type>-<time>/training-*.jsonl" + + - AIP_VALIDATION_DATA_URI = + + "gcs_destination/dataset-<dataset-id>-<annotation-type>-<time>/validation-*.jsonl" + + - AIP_TEST_DATA_URI = + + "gcs_destination/dataset-<dataset-id>-<annotation-type>-<time>/test-*.jsonl". dataset_id (str): Required. The ID of the Dataset in the same Project and Location which data will be used to train the Model. The @@ -237,13 +284,30 @@ class InputDataConfig(proto.Message): ``annotation_schema_uri``. """
- fraction_split = proto.Field(proto.MESSAGE, number=2, message="FractionSplit",) - filter_split = proto.Field(proto.MESSAGE, number=3, message="FilterSplit",) - predefined_split = proto.Field(proto.MESSAGE, number=4, message="PredefinedSplit",) - timestamp_split = proto.Field(proto.MESSAGE, number=5, message="TimestampSplit",) - gcs_destination = proto.Field(proto.MESSAGE, number=8, message=io.GcsDestination,) + fraction_split = proto.Field(proto.MESSAGE, number=2, oneof='split', + message='FractionSplit', + ) + + filter_split = proto.Field(proto.MESSAGE, number=3, oneof='split', + message='FilterSplit', + ) + + predefined_split = proto.Field(proto.MESSAGE, number=4, oneof='split', + message='PredefinedSplit', + ) + + timestamp_split = proto.Field(proto.MESSAGE, number=5, oneof='split', + message='TimestampSplit', + ) + + gcs_destination = proto.Field(proto.MESSAGE, number=8, oneof='destination', + message=io.GcsDestination, + ) + dataset_id = proto.Field(proto.STRING, number=1) + annotations_filter = proto.Field(proto.STRING, number=6) + annotation_schema_uri = proto.Field(proto.STRING, number=9)
@@ -269,7 +333,9 @@ class FractionSplit(proto.Message): """ training_fraction = proto.Field(proto.DOUBLE, number=1) + validation_fraction = proto.Field(proto.DOUBLE, number=2) + test_fraction = proto.Field(proto.DOUBLE, number=3) @@ -312,7 +378,9 @@ class FilterSplit(proto.Message): """ training_filter = proto.Field(proto.STRING, number=1) + validation_filter = proto.Field(proto.STRING, number=2) + test_filter = proto.Field(proto.STRING, number=3) @@ -363,8 +431,11 @@ class TimestampSplit(proto.Message): """ training_fraction = proto.Field(proto.DOUBLE, number=1) + validation_fraction = proto.Field(proto.DOUBLE, number=2) + test_fraction = proto.Field(proto.DOUBLE, number=3) + key = proto.Field(proto.STRING, number=4)
diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index ce868edc27..6e54a37598 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -19,7 +19,10 @@ __protobuf__ = proto.module( - package="google.cloud.aiplatform.v1beta1", manifest={"UserActionReference",}, + package='google.cloud.aiplatform.v1beta1', + manifest={ + 'UserActionReference', + }, ) @@ -44,8 +47,10 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset".
""" - operation = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field(proto.STRING, number=2) + operation = proto.Field(proto.STRING, number=1, oneof='reference') + + data_labeling_job = proto.Field(proto.STRING, number=2, oneof='reference') + method = proto.Field(proto.STRING, number=3) diff --git a/mypy.ini b/mypy.ini index f23e6b533a..4505b48543 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.5 +python_version = 3.6 namespace_packages = True diff --git a/noxfile.py b/noxfile.py index 615e2c6793..4f20df5c36 100644 --- a/noxfile.py +++ b/noxfile.py @@ -40,7 +40,9 @@ def lint(session): """ session.install("flake8", BLACK_VERSION) session.run( - "black", "--check", *BLACK_PATHS, + "black", + "--check", + *BLACK_PATHS, ) session.run("flake8", "google", "tests") @@ -57,7 +59,8 @@ def blacken(session): """ session.install(BLACK_VERSION) session.run( - "black", *BLACK_PATHS, + "black", + *BLACK_PATHS, ) @@ -71,7 +74,7 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. session.install("asyncmock", "pytest-asyncio") - + session.install("mock", "pytest", "pytest-cov") session.install("-e", ".") @@ -90,7 +93,6 @@ def default(session): *session.posargs, ) - @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" @@ -104,7 +106,7 @@ def system(session): system_test_folder_path = os.path.join("tests", "system") # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true. - if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": + if os.environ.get("RUN_SYSTEM_TESTS", "true") == 'false': session.skip("RUN_SYSTEM_TESTS is set to false, skipping") # Sanity check: Only run tests if the environment variable is set. if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): @@ -121,11 +123,10 @@ def system(session): # Install all test dependencies, then install this package into the # virtualenv's dist-packages. - session.install( - "mock", "pytest", "google-cloud-testutils", - ) + session.install("mock", "pytest", "google-cloud-testutils", ) session.install("-e", ".") + # Run py.test against the system tests. if system_test_exists: session.run("py.test", "--quiet", system_test_path, *session.posargs) @@ -133,6 +134,7 @@ def system(session): session.run("py.test", "--quiet", system_test_folder_path, *session.posargs) + @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): """Run the final coverage report. 
@@ -145,15 +147,14 @@ def cover(session): session.run("coverage", "erase") - @nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" - session.install("-e", ".") - session.install("sphinx", "alabaster", "recommonmark") + session.install('-e', '.') + session.install('sphinx', 'alabaster', 'recommonmark') - shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + shutil.rmtree(os.path.join('docs', '_build'), ignore_errors=True) session.run( "sphinx-build", "-T", # show full traceback on exception diff --git a/scripts/fixup_aiplatform_v1beta1_keywords.py b/scripts/fixup_aiplatform_v1beta1_keywords.py index 26a51a02c1..7188a7d5bc 100644 --- a/scripts/fixup_aiplatform_v1beta1_keywords.py +++ b/scripts/fixup_aiplatform_v1beta1_keywords.py @@ -40,6 +40,7 @@ def partition( class aiplatformCallTransformer(cst.CSTTransformer): CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { + 'batch_migrate_resources': ('parent', 'migrate_resource_requests', ), 'cancel_batch_prediction_job': ('name', ), 'cancel_custom_job': ('name', ), 'cancel_data_labeling_job': ('name', ), @@ -93,6 +94,7 @@ class aiplatformCallTransformer(cst.CSTTransformer): 'list_specialist_pools': ('parent', 'page_size', 'page_token', 'read_mask', ), 'list_training_pipelines': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), 'predict': ('endpoint', 'instances', 'parameters', ), + 'search_migratable_resources': ('parent', 'page_size', 'page_token', ), 'undeploy_model': ('endpoint', 'deployed_model_id', 'traffic_split', ), 'update_dataset': ('dataset', 'update_mask', ), 'update_endpoint': ('endpoint', 'update_mask', ), diff --git a/setup.py b/setup.py index 7cf86480e5..8f159e0dc7 100644 --- a/setup.py +++ b/setup.py @@ -15,46 +15,36 @@ # limitations under the License. 
# -import io -import os import setuptools # type: ignore -version = "0.2.0" - -package_root = os.path.abspath(os.path.dirname(__file__)) - -readme_filename = os.path.join(package_root, "README.rst") -with io.open(readme_filename, encoding="utf-8") as readme_file: - readme = readme_file.read() - setuptools.setup( - name="google-cloud-aiplatform", - version=version, - long_description=readme, - author="Google LLC", - author_email="googleapis-packages@google.com", - license="Apache 2.0", - url="https://github.com/googleapis/python-documentai", + name='google-cloud-aiplatform', + version='0.3.0', packages=setuptools.PEP420PackageFinder.find(), - namespace_packages=("google", "google.cloud"), - platforms="Posix; MacOS X; Windows", + namespace_packages=('google', 'google.cloud'), + platforms='Posix; MacOS X; Windows', include_package_data=True, install_requires=( - "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", - "libcst >= 0.2.5", - "proto-plus >= 1.4.0", + 'google-api-core[grpc] >= 1.22.2, < 2.0.0dev', + 'libcst >= 0.2.5', + 'proto-plus >= 1.4.0', + 'mock >= 4.0.2', + 'google-cloud-storage >= 1.26.0', ), - python_requires=">=3.6", + python_requires='>=3.6', + scripts=[ + 'scripts/fixup_aiplatform_v1beta1_keywords.py', + ], classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Topic :: Internet", - "Topic :: Software Development :: Libraries :: Python Modules", + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Internet', + 'Topic :: Software Development :: Libraries :: Python Modules', ], zip_safe=False, ) diff --git a/synth.metadata b/synth.metadata index 2f581d2593..a999634172 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,14 +4,14 @@ "git": { "name": ".", "remote": "https://github.com/dizcology/python-aiplatform.git", - "sha": "7e83ff65457e88aa155e68ddd959933a68da46af" + "sha": "0fc50abf7571c93a7810506e73c05a72b9f6efc0" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "487eba79f8260e34205d8ceb1ebcc65685085e19" + "sha": "77c5ba85e05950f5b19ce8a553c1c0db2fba9896" } } ], diff --git a/synth.py b/synth.py index 31c3e11493..ce8c810d80 100644 --- a/synth.py +++ b/synth.py @@ -35,7 +35,7 @@ # version="v1beta1", # bazel_target="//google/cloud/aiplatform/v1beta1:aiplatform-v1beta1-py", # ) -library = gapic.py_library("aiplatform", "v1beta1", generator_version="0.20") +library = gapic.py_library("aiplatform", "v1beta1") s.move( library, @@ -45,7 +45,7 @@ "README.rst", "docs/index.rst", "google/cloud/aiplatform/__init__.py", - "tests/unit/aiplatform_v1beta1/test_prediction_service.py", + "tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py", ], ) diff --git a/tests/unit/gapic/aiplatform_v1beta1/__init__.py b/tests/unit/gapic/aiplatform_v1beta1/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py new file mode 100644 index 0000000000..002b1afc4e --- /dev/null +++ 
b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -0,0 +1,3687 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers +from google.cloud.aiplatform_v1beta1.services.dataset_service import transports +from google.cloud.aiplatform_v1beta1.types import annotation +from google.cloud.aiplatform_v1beta1.types import annotation_spec +from google.cloud.aiplatform_v1beta1.types import data_item +from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
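+# For aiplatform the default endpoint is "aiplatform.googleapis.com", which does
+# not contain "localhost", so this helper returns it unchanged; only a localhost
+# default would be rewritten to the dummy "foo.googleapis.com".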
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DatasetServiceClient._get_default_mtls_endpoint(None) is None + assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [DatasetServiceClient, DatasetServiceAsyncClient]) +def test_dataset_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_dataset_service_client_get_transport_class(): + transport = DatasetServiceClient.get_transport_class() + assert transport == transports.DatasetServiceGrpcTransport + + transport = DatasetServiceClient.get_transport_class("grpc") + assert transport == transports.DatasetServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +def test_dataset_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
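+    # An explicitly supplied api_endpoint is used verbatim and bypasses the
+    # GOOGLE_API_USE_MTLS_ENDPOINT switching exercised further down.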
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
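+    # With neither an explicit client_cert_source nor ADC client certificates
+    # (is_mtls is mocked to False below), the client must fall back to the
+    # regular DEFAULT_ENDPOINT with no SSL channel credentials.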
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_dataset_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = DatasetServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_dataset_from_dict(): + test_create_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_dataset_async(transport: str = 'grpc_asyncio'): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.CreateDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_dataset_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.CreateDatasetRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_dataset_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.CreateDatasetRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_dataset_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_dataset( + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].dataset == gca_dataset.Dataset(name='name_value') + + +def test_create_dataset_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_dataset_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_dataset( + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].dataset == gca_dataset.Dataset(name='name_value') + + +@pytest.mark.asyncio +async def test_create_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), + ) + + +def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset( + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + + ) + + response = client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, dataset.Dataset) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.etag == 'etag_value' + + +def test_get_dataset_from_dict(): + test_get_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_dataset_async(transport: str = 'grpc_asyncio'): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.GetDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) + + response = await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, dataset.Dataset) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.etag == 'etag_value' + + +def test_get_dataset_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetDatasetRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: + call.return_value = dataset.Dataset() + + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_dataset_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = dataset_service.GetDatasetRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) + + await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_dataset_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_dataset( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_dataset_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_dataset( + dataset_service.GetDatasetRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_dataset_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_dataset( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_dataset( + dataset_service.GetDatasetRequest(), + name='name_value', + ) + + +def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset( + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + + ) + + response = client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_dataset.Dataset) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.etag == 'etag_value' + + +def test_update_dataset_from_dict(): + test_update_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_dataset_async(transport: str = 'grpc_asyncio'): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.UpdateDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) + + response = await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_dataset.Dataset) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.etag == 'etag_value' + + +def test_update_dataset_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.UpdateDatasetRequest() + request.dataset.name = 'dataset.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: + call.return_value = gca_dataset.Dataset() + + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_dataset_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
Set these to a non-empty value. + request = dataset_service.UpdateDatasetRequest() + request.dataset.name = 'dataset.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) + + await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] + + +def test_update_dataset_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_dataset( + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].dataset == gca_dataset.Dataset(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_dataset_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.asyncio +async def test_update_dataset_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_dataset( + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].dataset == gca_dataset.Dataset(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_datasets_from_dict(): + test_list_datasets(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_datasets_async(transport: str = 'grpc_asyncio'): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ListDatasetsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_datasets_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + call.return_value = dataset_service.ListDatasetsResponse() + + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
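+    # The x-goog-request-params entry asserted here is the routing header that
+    # GAPIC clients attach so the backend can route the call by resource name.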
+ _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_datasets_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + + await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_datasets_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_datasets( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_datasets_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_datasets( + dataset_service.ListDatasetsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_datasets_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_datasets( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_datasets_flattened_error_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        await client.list_datasets(
+            dataset_service.ListDatasetsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_datasets_pager():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_datasets),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_datasets(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, dataset.Dataset)
+                   for i in results)
+
+def test_list_datasets_pages():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_datasets),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_datasets(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_datasets_async_pager():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_datasets),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
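+        # Note: an iterable assigned to side_effect is consumed one item per
+        # call; the trailing RuntimeError is raised if the pager ever asks for
+        # more pages than the test supplies.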
+        call.side_effect = (
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_datasets(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, dataset.Dataset)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_datasets_async_pages():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_datasets),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDatasetsResponse(
+                datasets=[
+                    dataset.Dataset(),
+                    dataset.Dataset(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_datasets(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_dataset),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.delete_dataset(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == dataset_service.DeleteDatasetRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_dataset_from_dict():
+    test_delete_dataset(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_delete_dataset_async(transport: str = 'grpc_asyncio'):
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = dataset_service.DeleteDatasetRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
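+    # The async client awaits the stubbed call, so the fake response below is
+    # wrapped in grpc_helpers_async.FakeUnaryUnaryCall to make it awaitable.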
+ with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_dataset_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.DeleteDatasetRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_dataset_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.DeleteDatasetRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_dataset_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_dataset( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_dataset_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.delete_dataset(
+            dataset_service.DeleteDatasetRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_dataset_flattened_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_dataset),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_dataset(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_dataset_flattened_error_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_dataset(
+            dataset_service.DeleteDatasetRequest(),
+            name='name_value',
+        )
+
+
+def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.import_data),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.import_data(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == dataset_service.ImportDataRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_import_data_from_dict():
+    test_import_data(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_import_data_async(transport: str = 'grpc_asyncio'):
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = dataset_service.ImportDataRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.import_data),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+
+        response = await client.import_data(request)
+
+        # Establish that the underlying gRPC stub method was called.
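+        # The async tests assert only a truthy call count here; the sync
+        # variants above check for exactly one recorded call.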
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_import_data_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ImportDataRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_import_data_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ImportDataRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_import_data_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.import_data( + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + + +def test_import_data_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
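+    # Mixing a prebuilt request object with flattened keyword overrides would
+    # be ambiguous, so the client raises ValueError rather than guessing.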
+    with pytest.raises(ValueError):
+        client.import_data(
+            dataset_service.ImportDataRequest(),
+            name='name_value',
+            import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))],
+        )
+
+
+@pytest.mark.asyncio
+async def test_import_data_flattened_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.import_data),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.import_data(
+            name='name_value',
+            import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+        assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))]
+
+
+@pytest.mark.asyncio
+async def test_import_data_flattened_error_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.import_data(
+            dataset_service.ImportDataRequest(),
+            name='name_value',
+            import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))],
+        )
+
+
+def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.export_data),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.export_data(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == dataset_service.ExportDataRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_export_data_from_dict():
+    test_export_data(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_export_data_async(transport: str = 'grpc_asyncio'):
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = dataset_service.ExportDataRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.export_data),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_data_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ExportDataRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_export_data_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ExportDataRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_export_data_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.export_data( + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
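+        # The flattened kwargs were coerced into an ExportDataRequest, so the
+        # assertions below inspect fields on the positional request argument.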
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+        assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value'))
+
+
+def test_export_data_flattened_error():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.export_data(
+            dataset_service.ExportDataRequest(),
+            name='name_value',
+            export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')),
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_data_flattened_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.export_data),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.export_data(
+            name='name_value',
+            export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+        assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value'))
+
+
+@pytest.mark.asyncio
+async def test_export_data_flattened_error_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.export_data(
+            dataset_service.ExportDataRequest(),
+            name='name_value',
+            export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')),
+        )
+
+
+def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = dataset_service.ListDataItemsResponse(
+            next_page_token='next_page_token_value',
+        )
+
+        response = client.list_data_items(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == dataset_service.ListDataItemsRequest()
+
+    # Establish that the response is the type that we expect.
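+    # The client wraps the raw ListDataItemsResponse in a pager that lazily
+    # fetches any subsequent pages as it is iterated.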
+ assert isinstance(response, pagers.ListDataItemsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_data_items_from_dict(): + test_list_data_items(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_data_items_async(transport: str = 'grpc_asyncio'): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ListDataItemsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataItemsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_data_items_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: + call.return_value = dataset_service.ListDataItemsResponse() + + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_data_items_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + + await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_data_items_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = dataset_service.ListDataItemsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_data_items(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_data_items_flattened_error():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_data_items(
+            dataset_service.ListDataItemsRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_flattened_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_data_items(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_flattened_error_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_data_items(
+            dataset_service.ListDataItemsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_data_items_pager():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_data_items(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, data_item.DataItem)
+                   for i in results)
+
+def test_list_data_items_pages():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_data_items(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_data_items_async_pager():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_data_items(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, data_item.DataItem)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_data_items_async_pages():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_data_items),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token='abc', + ), + dataset_service.ListDataItemsResponse( + data_items=[], + next_page_token='def', + ), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', + ), + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_data_items(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec( + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + + ) + + response = client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, annotation_spec.AnnotationSpec) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + +def test_get_annotation_spec_from_dict(): + test_get_annotation_spec(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio'): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.GetAnnotationSpecRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( + name='name_value', + display_name='display_name_value', + etag='etag_value', + )) + + response = await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
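+    # Unlike the List* methods, GetAnnotationSpec is a plain unary call, so
+    # the response is the resource message itself rather than a pager.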
+ assert isinstance(response, annotation_spec.AnnotationSpec) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.etag == 'etag_value' + + +def test_get_annotation_spec_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), + '__call__') as call: + call.return_value = annotation_spec.AnnotationSpec() + + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_annotation_spec_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + + await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_annotation_spec_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_annotation_spec( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_annotation_spec_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_annotation_spec( + dataset_service.GetAnnotationSpecRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_annotation_spec_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.get_annotation_spec),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_annotation_spec(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_annotation_spec_flattened_error_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_annotation_spec(
+            dataset_service.GetAnnotationSpecRequest(),
+            name='name_value',
+        )
+
+
+def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = dataset_service.ListAnnotationsResponse(
+            next_page_token='next_page_token_value',
+        )
+
+        response = client.list_annotations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == dataset_service.ListAnnotationsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListAnnotationsPager)
+
+    assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_annotations_from_dict():
+    test_list_annotations(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_async(transport: str = 'grpc_asyncio'):
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = dataset_service.ListAnnotationsRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse(
+            next_page_token='next_page_token_value',
+        ))
+
+        response = await client.list_annotations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListAnnotationsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_annotations_field_headers(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: + call.return_value = dataset_service.ListAnnotationsResponse() + + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_annotations_field_headers_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + + await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_annotations_flattened(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_annotations( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_annotations_flattened_error(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_annotations( + dataset_service.ListAnnotationsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_annotations_flattened_async(): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_annotations(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_flattened_error_async():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_annotations(
+            dataset_service.ListAnnotationsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_annotations_pager():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_annotations(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, annotation.Annotation)
+                   for i in results)
+
+def test_list_annotations_pages():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_annotations(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_annotations_async_pager():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_annotations(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, annotation.Annotation)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_annotations_async_pages():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_annotations),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token='abc',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[],
+                next_page_token='def',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                ],
+                next_page_token='ghi',
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_annotations(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
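+    # A transport instance already carries its own credentials, so passing
+    # credentials, a credentials file, or scopes alongside one is rejected.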
+ transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = DatasetServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DatasetServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DatasetServiceGrpcTransport, + ) + + +def test_dataset_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_dataset_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
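+    # The base transport only defines the interface; the concrete gRPC and
+    # gRPC-asyncio transports are expected to override each of these methods.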
+ methods = ( + 'create_dataset', + 'get_dataset', + 'update_dataset', + 'list_datasets', + 'delete_dataset', + 'import_data', + 'export_data', + 'list_data_items', + 'get_annotation_spec', + 'list_annotations', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_dataset_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.DatasetServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_dataset_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.DatasetServiceTransport() + adc.assert_called_once() + + +def test_dataset_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + DatasetServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_dataset_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_dataset_service_host_no_port(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_dataset_service_host_with_port(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_dataset_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
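+    # When a caller injects a ready-made gRPC channel, the transport is
+    # expected to adopt it as-is (no new channel creation), while the host
+    # string is still normalized with the default :443 port. A hypothetical
+    # caller sketch:
+    #
+    #     channel = grpc.secure_channel('aiplatform.googleapis.com:443', creds)
+    #     transport = transports.DatasetServiceGrpcTransport(channel=channel)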
+    transport = transports.DatasetServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+def test_dataset_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel('http://localhost/')
+
+    # Check that channel is used if provided.
+    transport = transports.DatasetServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport])
+def test_dataset_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport])
+def test_dataset_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_dataset_service_grpc_lro_client():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property return the exact same object.
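+    # The operations_client property is expected to be memoized on first
+    # access, which is why an identity (``is``) comparison suffices below.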
+    assert transport.operations_client is transport.operations_client
+
+
+def test_dataset_service_grpc_lro_async_client():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsAsyncClient,
+    )
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+def test_annotation_path():
+    project = "squid"
+    location = "clam"
+    dataset = "whelk"
+    data_item = "octopus"
+    annotation = "oyster"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, )
+    actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation)
+    assert expected == actual
+
+
+def test_parse_annotation_path():
+    expected = {
+        "project": "nudibranch",
+        "location": "cuttlefish",
+        "dataset": "mussel",
+        "data_item": "winkle",
+        "annotation": "nautilus",
+
+    }
+    path = DatasetServiceClient.annotation_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = DatasetServiceClient.parse_annotation_path(path)
+    assert expected == actual
+
+def test_annotation_spec_path():
+    project = "scallop"
+    location = "abalone"
+    dataset = "squid"
+    annotation_spec = "clam"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, )
+    actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec)
+    assert expected == actual
+
+
+def test_parse_annotation_spec_path():
+    expected = {
+        "project": "whelk",
+        "location": "octopus",
+        "dataset": "oyster",
+        "annotation_spec": "nudibranch",
+
+    }
+    path = DatasetServiceClient.annotation_spec_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = DatasetServiceClient.parse_annotation_spec_path(path)
+    assert expected == actual
+
+def test_data_item_path():
+    project = "cuttlefish"
+    location = "mussel"
+    dataset = "winkle"
+    data_item = "nautilus"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, )
+    actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item)
+    assert expected == actual
+
+
+def test_parse_data_item_path():
+    expected = {
+        "project": "scallop",
+        "location": "abalone",
+        "dataset": "squid",
+        "data_item": "clam",
+
+    }
+    path = DatasetServiceClient.data_item_path(**expected)
+
+    # Check that the path construction is reversible.
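+    # The parse_*_path helpers invert the template expansion done by *_path,
+    # e.g. (hypothetical values):
+    #
+    #     DatasetServiceClient.parse_data_item_path(
+    #         "projects/p/locations/l/datasets/d/dataItems/i"
+    #     ) == {"project": "p", "location": "l", "dataset": "d", "data_item": "i"}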
+ actual = DatasetServiceClient.parse_data_item_path(path) + assert expected == actual + +def test_dataset_path(): + project = "whelk" + location = "octopus" + dataset = "oyster" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = DatasetServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + + } + path = DatasetServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_dataset_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "winkle" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = DatasetServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nautilus", + + } + path = DatasetServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "scallop" + + expected = "folders/{folder}".format(folder=folder, ) + actual = DatasetServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "abalone", + + } + path = DatasetServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "squid" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = DatasetServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "clam", + + } + path = DatasetServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "whelk" + + expected = "projects/{project}".format(project=project, ) + actual = DatasetServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "octopus", + + } + path = DatasetServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "oyster" + location = "nudibranch" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = DatasetServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + + } + path = DatasetServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
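+    # The common_*_path helpers (billing account, folder, organization,
+    # project, location) cover resource containers shared by every GAPIC
+    # client and round-trip the same way as the service-specific paths above.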
+ actual = DatasetServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = DatasetServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py new file mode 100644 index 0000000000..3d638675f1 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -0,0 +1,2604 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceClient +from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + 
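+    # Stand-in for a user-supplied client certificate source. A real caller
+    # would typically load an mTLS cert/key pair instead, e.g. (sketch) via
+    # google.auth.transport.mtls.default_client_cert_source().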
return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert EndpointServiceClient._get_default_mtls_endpoint(None) is None + assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [EndpointServiceClient, EndpointServiceAsyncClient]) +def test_endpoint_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_endpoint_service_client_get_transport_class(): + transport = EndpointServiceClient.get_transport_class() + assert transport == transports.EndpointServiceGrpcTransport + + transport = EndpointServiceClient.get_transport_class("grpc") + assert transport == transports.EndpointServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
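+    # Decision matrix exercised by this test: the mTLS endpoint and client
+    # certificate are used only when GOOGLE_API_USE_CLIENT_CERTIFICATE is
+    # "true" AND a certificate source is available (explicit callback or ADC);
+    # in every other combination the default endpoint is expected with
+    # ssl_channel_credentials=None.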
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_endpoint_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = EndpointServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
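+    # CreateEndpoint is a long-running operation: the stub yields a raw
+    # operations_pb2.Operation, which the client wraps in an api-core future.
+    # Hypothetical caller sketch:
+    #
+    #     lro = client.create_endpoint(parent=parent, endpoint=my_endpoint)
+    #     endpoint = lro.result()  # blocks until the operation completes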
+ with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_endpoint_from_dict(): + test_create_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_endpoint_async(transport: str = 'grpc_asyncio'): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.CreateEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_endpoint_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.CreateEndpointRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.CreateEndpointRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
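+    # The x-goog-request-params metadata entry mirrors URI-bound request
+    # fields so server-side routing can direct the call, e.g.
+    # ('x-goog-request-params', 'parent=projects/p/locations/l') for a
+    # hypothetical parent value.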
+ _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_endpoint_flattened(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_endpoint( + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + + +def test_create_endpoint_flattened_error(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_endpoint_flattened_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_endpoint( + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + + +@pytest.mark.asyncio +async def test_create_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), + ) + + +def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint( + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + + ) + + response = client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, endpoint.Endpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +def test_get_endpoint_from_dict(): + test_get_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_endpoint_async(transport: str = 'grpc_asyncio'): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.GetEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) + + response = await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, endpoint.Endpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +def test_get_endpoint_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.GetEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: + call.return_value = endpoint.Endpoint() + + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = endpoint_service.GetEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) + + await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_endpoint_flattened(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_endpoint_flattened_error(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_endpoint( + endpoint_service.GetEndpointRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_endpoint_flattened_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_endpoint( + endpoint_service.GetEndpointRequest(), + name='name_value', + ) + + +def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEndpointsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_endpoints_from_dict(): + test_list_endpoints(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_endpoints_async(transport: str = 'grpc_asyncio'): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.ListEndpointsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEndpointsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_endpoints_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + call.return_value = endpoint_service.ListEndpointsResponse() + + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_endpoints_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + + await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_endpoints_flattened(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_endpoints( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_endpoints_flattened_error(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_endpoints( + endpoint_service.ListEndpointsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_endpoints_flattened_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_endpoints( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +@pytest.mark.asyncio +async def test_list_endpoints_flattened_error_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_endpoints( + endpoint_service.ListEndpointsRequest(), + parent='parent_value', + ) + + +def test_list_endpoints_pager(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + # Set the response to a series of pages. 
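+        # Four fake pages (3 + 0 + 1 + 2 endpoints) followed by RuntimeError:
+        # the pager should stop after the final token-less page and never
+        # reach the trailing exception. Hypothetical caller sketch:
+        #
+        #     for endpoint in client.list_endpoints(parent=parent):
+        #         ...  # yields all 6 endpoints across pages transparently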
+ call.side_effect = ( + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + next_page_token='abc', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[], + next_page_token='def', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), + ) + pager = client.list_endpoints(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, endpoint.Endpoint) + for i in results) + +def test_list_endpoints_pages(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + # Set the response to a series of pages. + call.side_effect = ( + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + next_page_token='abc', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[], + next_page_token='def', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + ), + RuntimeError, + ) + pages = list(client.list_endpoints(request={}).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.asyncio +async def test_list_endpoints_async_pager(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. + call.side_effect = ( + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + next_page_token='abc', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[], + next_page_token='def', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_endpoints(request={},) + assert async_pager.next_page_token == 'abc' + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, endpoint.Endpoint) + for i in responses) + +@pytest.mark.asyncio +async def test_list_endpoints_async_pages(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + next_page_token='abc', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[], + next_page_token='def', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', + ), + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_endpoints(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_endpoint.Endpoint( + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + + ) + + response = client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_endpoint.Endpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +def test_update_endpoint_from_dict(): + test_update_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_endpoint_async(transport: str = 'grpc_asyncio'): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.UpdateEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) + + response = await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
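+    # Note: ``gca_endpoint`` is the same endpoint types module imported under
+    # a second alias (the generated convention for response-side messages), so
+    # the isinstance check below is exact.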
+ assert isinstance(response, gca_endpoint.Endpoint) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.etag == 'etag_value' + + +def test_update_endpoint_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UpdateEndpointRequest() + request.endpoint.name = 'endpoint.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + call.return_value = gca_endpoint.Endpoint() + + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UpdateEndpointRequest() + request.endpoint.name = 'endpoint.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + + await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] + + +def test_update_endpoint_flattened(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_endpoint( + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +def test_update_endpoint_flattened_error(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
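+    # The mixed call below is expected to be rejected client-side with
+    # ValueError, before any request object is built or any RPC attempted.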
+    with pytest.raises(ValueError):
+        client.update_endpoint(
+            endpoint_service.UpdateEndpointRequest(),
+            endpoint=gca_endpoint.Endpoint(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+
+@pytest.mark.asyncio
+async def test_update_endpoint_flattened_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_endpoint),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_endpoint(
+            endpoint=gca_endpoint.Endpoint(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value')
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value'])
+
+
+@pytest.mark.asyncio
+async def test_update_endpoint_flattened_error_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.update_endpoint(
+            endpoint_service.UpdateEndpointRequest(),
+            endpoint=gca_endpoint.Endpoint(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+
+def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest):
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_endpoint),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.delete_endpoint(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == endpoint_service.DeleteEndpointRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_endpoint_from_dict():
+    test_delete_endpoint(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_delete_endpoint_async(transport: str = 'grpc_asyncio'):
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = endpoint_service.DeleteEndpointRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
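+    # DeleteEndpoint is a long-running operation: the stub is faked with a
+    # raw Operation proto, and the client surfaces it as an api-core future.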
+ with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_endpoint_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeleteEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeleteEndpointRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_endpoint_flattened(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_endpoint( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_endpoint_flattened_error(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.delete_endpoint(
+            endpoint_service.DeleteEndpointRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_endpoint_flattened_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_endpoint),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_endpoint(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_endpoint_flattened_error_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_endpoint(
+            endpoint_service.DeleteEndpointRequest(),
+            name='name_value',
+        )
+
+
+def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest):
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.deploy_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.deploy_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == endpoint_service.DeployModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_deploy_model_from_dict():
+    test_deploy_model(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_deploy_model_async(transport: str = 'grpc_asyncio'):
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = endpoint_service.DeployModelRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.deploy_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+
+        response = await client.deploy_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_deploy_model_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeployModelRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_deploy_model_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeployModelRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +def test_deploy_model_flattened(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.deploy_model( + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value')))
+
+        assert args[0].traffic_split == {'key_value': 541}
+
+
+def test_deploy_model_flattened_error():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.deploy_model(
+            endpoint_service.DeployModelRequest(),
+            endpoint='endpoint_value',
+            deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))),
+            traffic_split={'key_value': 541},
+        )
+
+
+@pytest.mark.asyncio
+async def test_deploy_model_flattened_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.deploy_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.deploy_model(
+            endpoint='endpoint_value',
+            deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))),
+            traffic_split={'key_value': 541},
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value')))
+
+        assert args[0].traffic_split == {'key_value': 541}
+
+
+@pytest.mark.asyncio
+async def test_deploy_model_flattened_error_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.deploy_model(
+            endpoint_service.DeployModelRequest(),
+            endpoint='endpoint_value',
+            deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))),
+            traffic_split={'key_value': 541},
+        )
+
+
+def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest):
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_undeploy_model_from_dict(): + test_undeploy_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_undeploy_model_async(transport: str = 'grpc_asyncio'): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.UndeployModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_undeploy_model_field_headers(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UndeployModelRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_undeploy_model_field_headers_async(): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UndeployModelRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'endpoint=endpoint/value',
+    ) in kw['metadata']
+
+
+def test_undeploy_model_flattened():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.undeploy_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.undeploy_model(
+            endpoint='endpoint_value',
+            deployed_model_id='deployed_model_id_value',
+            traffic_split={'key_value': 541},
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].deployed_model_id == 'deployed_model_id_value'
+
+        assert args[0].traffic_split == {'key_value': 541}
+
+
+def test_undeploy_model_flattened_error():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.undeploy_model(
+            endpoint_service.UndeployModelRequest(),
+            endpoint='endpoint_value',
+            deployed_model_id='deployed_model_id_value',
+            traffic_split={'key_value': 541},
+        )
+
+
+@pytest.mark.asyncio
+async def test_undeploy_model_flattened_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.undeploy_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.undeploy_model(
+            endpoint='endpoint_value',
+            deployed_model_id='deployed_model_id_value',
+            traffic_split={'key_value': 541},
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].deployed_model_id == 'deployed_model_id_value'
+
+        assert args[0].traffic_split == {'key_value': 541}
+
+
+@pytest.mark.asyncio
+async def test_undeploy_model_flattened_error_async():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.undeploy_model(
+            endpoint_service.UndeployModelRequest(),
+            endpoint='endpoint_value',
+            deployed_model_id='deployed_model_id_value',
+            traffic_split={'key_value': 541},
+        )
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
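+    # A transport instance already carries its own credentials, so also
+    # passing credentials, a credentials file, or scopes to the client is
+    # ambiguous and is rejected with ValueError.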
+ transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = EndpointServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.EndpointServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.EndpointServiceGrpcTransport, + ) + + +def test_endpoint_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_endpoint_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
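+    # The base transport is only an interface; the concrete gRPC and
+    # gRPC-asyncio transports are expected to override each of these methods.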
+ methods = ( + 'create_endpoint', + 'get_endpoint', + 'list_endpoints', + 'update_endpoint', + 'delete_endpoint', + 'deploy_model', + 'undeploy_model', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_endpoint_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.EndpointServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_endpoint_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.EndpointServiceTransport() + adc.assert_called_once() + + +def test_endpoint_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + EndpointServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_endpoint_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_endpoint_service_host_no_port(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_endpoint_service_host_with_port(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_endpoint_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
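+    # A caller-supplied channel should be adopted as-is: the transport still
+    # records the host (with the default :443 suffix) but does not dial a
+    # new channel.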
+    transport = transports.EndpointServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+def test_endpoint_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel('http://localhost/')
+
+    # Check that channel is used if provided.
+    transport = transports.EndpointServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport])
+def test_endpoint_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport])
+def test_endpoint_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_endpoint_service_grpc_lro_client():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
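+    # (In other words, the operations client is created once and cached by
+    # the property, not rebuilt on every access.)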
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_service_grpc_lro_async_client():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsAsyncClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, )
+    actual = EndpointServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+
+    }
+    path = EndpointServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    actual = EndpointServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+
+    }
+    path = EndpointServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_model_path(path)
+    assert expected == actual
+
+def test_common_billing_account_path():
+    billing_account = "squid"
+
+    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    actual = EndpointServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "clam",
+
+    }
+    path = EndpointServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+def test_common_folder_path():
+    folder = "whelk"
+
+    expected = "folders/{folder}".format(folder=folder, )
+    actual = EndpointServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "octopus",
+
+    }
+    path = EndpointServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+def test_common_organization_path():
+    organization = "oyster"
+
+    expected = "organizations/{organization}".format(organization=organization, )
+    actual = EndpointServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "nudibranch",
+
+    }
+    path = EndpointServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+ actual = EndpointServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project, ) + actual = EndpointServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + + } + path = EndpointServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = EndpointServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + + } + path = EndpointServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = EndpointServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py new file mode 100644 index 0000000000..fa57d58228 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -0,0 +1,6289 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient +from google.cloud.aiplatform_v1beta1.services.job_service import pagers +from google.cloud.aiplatform_v1beta1.services.job_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import completion_stats +from google.cloud.aiplatform_v1beta1.types import custom_job +from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import study +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 as gp_any # type: ignore +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
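+# (The mtls endpoint is derived by inserting ".mtls" after the first label,
+# e.g. "example.googleapis.com" -> "example.mtls.googleapis.com"; see
+# test__get_default_mtls_endpoint below.)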
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert JobServiceClient._get_default_mtls_endpoint(None) is None + assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) +def test_job_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_job_service_client_get_transport_class(): + transport = JobServiceClient.get_transport_class() + assert transport == transports.JobServiceGrpcTransport + + transport = JobServiceClient.get_transport_class("grpc") + assert transport == transports.JobServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +def test_job_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". 
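+    # "never" pins the client to the regular endpoint even when certificates
+    # are configured; the "always" case below forces DEFAULT_MTLS_ENDPOINT.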
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
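+    # With "true" the client should switch to the mTLS endpoint and hand the
+    # SSL channel credentials to the transport; with "false" it must stay on
+    # the default endpoint with no client certificate.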
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
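+    # Scopes passed through client_options should be forwarded verbatim to
+    # the transport constructor.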
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_job_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = JobServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob( + name='name_value', + + display_name='display_name_value', + + state=job_state.JobState.JOB_STATE_QUEUED, + + ) + + response = client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_custom_job.CustomJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_custom_job_from_dict(): + test_create_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_custom_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) + + response = await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_custom_job.CustomJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_custom_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateCustomJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), + '__call__') as call: + call.return_value = gca_custom_job.CustomJob() + + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_custom_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateCustomJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + + await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
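+    # kw['metadata'] holds the outgoing gRPC metadata tuples. The generated
+    # clients encode the request's resource path into x-goog-request-params
+    # (e.g. parent='projects/p' yields ('x-goog-request-params',
+    # 'parent=projects/p')) so the backend can route the call.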
+ _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_custom_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_custom_job( + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + + +def test_create_custom_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_custom_job( + job_service.CreateCustomJobRequest(), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_custom_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_custom_job( + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + + +@pytest.mark.asyncio +async def test_create_custom_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_custom_job( + job_service.CreateCustomJobRequest(), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), + ) + + +def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
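+    # Patching __call__ on the type of the stub callable intercepts the RPC
+    # at the channel boundary, so the client-side request path above it
+    # still runs for real.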
+ with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob( + name='name_value', + + display_name='display_name_value', + + state=job_state.JobState.JOB_STATE_QUEUED, + + ) + + response = client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, custom_job.CustomJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_custom_job_from_dict(): + test_get_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_custom_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.GetCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) + + response = await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, custom_job.CustomJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_custom_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + call.return_value = custom_job.CustomJob() + + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_custom_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + + await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_custom_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_custom_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_custom_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_custom_job( + job_service.GetCustomJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_custom_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_custom_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_custom_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_custom_job( + job_service.GetCustomJobRequest(), + name='name_value', + ) + + +def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. 
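+        # For list RPCs the client wraps the raw response in a pager, so the
+        # assertions below check the pager type plus the echoed page token.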
+ call.return_value = job_service.ListCustomJobsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListCustomJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCustomJobsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_custom_jobs_from_dict(): + test_list_custom_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.ListCustomJobsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCustomJobsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_custom_jobs_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListCustomJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + call.return_value = job_service.ListCustomJobsResponse() + + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_custom_jobs_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListCustomJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + + await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
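+        # Unlike the sync variant's == 1 check, the async tests only assert
+        # that mock_calls is non-empty, presumably because the async mock
+        # plumbing can record extra entries.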
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+def test_list_custom_jobs_flattened():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_custom_jobs),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListCustomJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_custom_jobs(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_custom_jobs_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_custom_jobs(
+            job_service.ListCustomJobsRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_custom_jobs),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_custom_jobs(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_custom_jobs(
+            job_service.ListCustomJobsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_custom_jobs_pager():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_custom_jobs),
+            '__call__') as call:
+        # Set the response to a series of pages.
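+        # side_effect feeds one response per underlying call: four pages of
+        # varying sizes (the last with no token), then a RuntimeError
+        # sentinel that would surface if the pager fetched past the last page.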
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_custom_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, custom_job.CustomJob)
+                   for i in results)
+
+def test_list_custom_jobs_pages():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_custom_jobs),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_custom_jobs(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_async_pager():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_custom_jobs),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_custom_jobs(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, custom_job.CustomJob)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_async_pages():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_custom_jobs),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + custom_job.CustomJob(), + ], + next_page_token='abc', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[], + next_page_token='def', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_custom_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_custom_job_from_dict(): + test_delete_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_custom_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.DeleteCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_custom_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteCustomJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
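+    # DeleteCustomJob is a long-running operation, so the mocked stub returns
+    # a raw operations_pb2.Operation; test_delete_custom_job above asserts
+    # that the client wraps it in a future.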
+ with mock.patch.object( + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_custom_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteCustomJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_custom_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_custom_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_delete_custom_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_custom_job( + job_service.DeleteCustomJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_delete_custom_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_custom_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_custom_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_custom_job( + job_service.DeleteCustomJobRequest(), + name='name_value', + ) + + +def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_custom_job_from_dict(): + test_cancel_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_custom_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelCustomJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), + '__call__') as call: + call.return_value = None + + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_cancel_custom_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelCustomJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_cancel_custom_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_custom_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_cancel_custom_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_custom_job( + job_service.CancelCustomJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_cancel_custom_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_custom_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_cancel_custom_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
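+    # Mixing a request object with flattened keyword arguments is ambiguous,
+    # so the generated client rejects the call client-side before any RPC
+    # is attempted.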
+ with pytest.raises(ValueError): + await client.cancel_custom_job( + job_service.CancelCustomJobRequest(), + name='name_value', + ) + + +def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob( + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + + labeler_count=1375, + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + + state=job_state.JobState.JOB_STATE_QUEUED, + + labeling_progress=1810, + + specialist_pools=['specialist_pools_value'], + + ) + + response = client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.datasets == ['datasets_value'] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == 'instruction_uri_value' + + assert response.inputs_schema_uri == 'inputs_schema_uri_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ['specialist_pools_value'] + + +def test_create_data_labeling_job_from_dict(): + test_create_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) + + response = await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.datasets == ['datasets_value'] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == 'instruction_uri_value' + + assert response.inputs_schema_uri == 'inputs_schema_uri_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ['specialist_pools_value'] + + +def test_create_data_labeling_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateDataLabelingJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), + '__call__') as call: + call.return_value = gca_data_labeling_job.DataLabelingJob() + + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateDataLabelingJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + + await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_data_labeling_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_data_labeling_job( + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + + +def test_create_data_labeling_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_data_labeling_job( + job_service.CreateDataLabelingJobRequest(), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + ) + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_data_labeling_job( + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_data_labeling_job( + job_service.CreateDataLabelingJobRequest(), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + ) + + +def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob( + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + + labeler_count=1375, + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + + state=job_state.JobState.JOB_STATE_QUEUED, + + labeling_progress=1810, + + specialist_pools=['specialist_pools_value'], + + ) + + response = client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, data_labeling_job.DataLabelingJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.datasets == ['datasets_value'] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == 'instruction_uri_value' + + assert response.inputs_schema_uri == 'inputs_schema_uri_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ['specialist_pools_value'] + + +def test_get_data_labeling_job_from_dict(): + test_get_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.GetDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) + + response = await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, data_labeling_job.DataLabelingJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.datasets == ['datasets_value'] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == 'instruction_uri_value' + + assert response.inputs_schema_uri == 'inputs_schema_uri_value' + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ['specialist_pools_value'] + + +def test_get_data_labeling_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), + '__call__') as call: + call.return_value = data_labeling_job.DataLabelingJob() + + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + + await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_data_labeling_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_data_labeling_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_data_labeling_job_flattened_error(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_data_labeling_job( + job_service.GetDataLabelingJobRequest(), + name='name_value', + ) + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_flattened_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_data_labeling_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_data_labeling_job( + job_service.GetDataLabelingJobRequest(), + name='name_value', + ) + + +def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListDataLabelingJobsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListDataLabelingJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataLabelingJobsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_data_labeling_jobs_from_dict(): + test_list_data_labeling_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.ListDataLabelingJobsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_data_labeling_jobs_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListDataLabelingJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__') as call:
+ call.return_value = job_service.ListDataLabelingJobsResponse()
+
+ client.list_data_labeling_jobs(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'parent=parent/value',
+ ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_field_headers_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = job_service.ListDataLabelingJobsRequest()
+ request.parent = 'parent/value'
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__') as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse())
+
+ await client.list_data_labeling_jobs(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'parent=parent/value',
+ ) in kw['metadata']
+
+
+def test_list_data_labeling_jobs_flattened():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = job_service.ListDataLabelingJobsResponse()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_data_labeling_jobs(
+ parent='parent_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == 'parent_value'
+
+
+def test_list_data_labeling_jobs_flattened_error():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_data_labeling_jobs(
+ job_service.ListDataLabelingJobsRequest(),
+ parent='parent_value',
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_flattened_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_data_labeling_jobs(
+ parent='parent_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_flattened_error_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_data_labeling_jobs(
+ job_service.ListDataLabelingJobsRequest(),
+ parent='parent_value',
+ )
+
+
+def test_list_data_labeling_jobs_pager():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__') as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((
+ ('parent', ''),
+ )),
+ )
+ pager = client.list_data_labeling_jobs(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, data_labeling_job.DataLabelingJob)
+ for i in results)
+
+def test_list_data_labeling_jobs_pages():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__') as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_data_labeling_jobs(request={}).pages)
+ for page_, token in zip(pages, ['abc','def','ghi', '']):
+ assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async_pager():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
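+ # Note: patching with new_callable=mock.AsyncMock below makes every stub
+ # invocation return an awaitable, so the async pager can await one entry
+ # of side_effect per page it fetches.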
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__', new_callable=mock.AsyncMock) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_data_labeling_jobs(request={},)
+ assert async_pager.next_page_token == 'abc'
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, data_labeling_job.DataLabelingJob)
+ for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async_pages():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ '__call__', new_callable=mock.AsyncMock) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (await client.list_data_labeling_jobs(request={})).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ['abc','def','ghi', '']):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_data_labeling_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name='operations/spam')
+
+ response = client.delete_data_labeling_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.DeleteDataLabelingJobRequest()
+
+ # Establish that the response is the type that we expect.
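+ # DeleteDataLabelingJob is a long-running operation, so the client is
+ # expected to wrap the raw operations_pb2.Operation in an api_core
+ # operation future; the assertion below checks the wrapper type.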
+ assert isinstance(response, future.Future) + + +def test_delete_data_labeling_job_from_dict(): + test_delete_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.DeleteDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_data_labeling_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteDataLabelingJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteDataLabelingJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_data_labeling_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(
+ type(client.transport.delete_data_labeling_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name='operations/op')
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.delete_data_labeling_job(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == 'name_value'
+
+
+def test_delete_data_labeling_job_flattened_error():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.delete_data_labeling_job(
+ job_service.DeleteDataLabelingJobRequest(),
+ name='name_value',
+ )
+
+
+@pytest.mark.asyncio
+async def test_delete_data_labeling_job_flattened_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_data_labeling_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name='operations/spam')
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.delete_data_labeling_job(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_data_labeling_job_flattened_error_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.delete_data_labeling_job(
+ job_service.DeleteDataLabelingJobRequest(),
+ name='name_value',
+ )
+
+
+def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.cancel_data_labeling_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = None
+
+ response = client.cancel_data_labeling_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.CancelDataLabelingJobRequest()
+
+ # Establish that the response is the type that we expect.
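+ # CancelDataLabelingJob returns google.protobuf.Empty on the wire, which
+ # the generated client surfaces to the caller as None.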
+ assert response is None + + +def test_cancel_data_labeling_job_from_dict(): + test_cancel_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_data_labeling_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelDataLabelingJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), + '__call__') as call: + call.return_value = None + + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelDataLabelingJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_cancel_data_labeling_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ client.cancel_data_labeling_job(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == 'name_value'
+
+
+def test_cancel_data_labeling_job_flattened_error():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.cancel_data_labeling_job(
+ job_service.CancelDataLabelingJobRequest(),
+ name='name_value',
+ )
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_flattened_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.cancel_data_labeling_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.cancel_data_labeling_job(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_flattened_error_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.cancel_data_labeling_job(
+ job_service.CancelDataLabelingJobRequest(),
+ name='name_value',
+ )
+
+
+def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+ name='name_value',
+
+ display_name='display_name_value',
+
+ max_trial_count=1609,
+
+ parallel_trial_count=2128,
+
+ max_failed_trial_count=2317,
+
+ state=job_state.JobState.JOB_STATE_QUEUED,
+
+ )
+
+ response = client.create_hyperparameter_tuning_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.CreateHyperparameterTuningJobRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_hyperparameter_tuning_job_from_dict(): + test_create_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) + + response = await client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateHyperparameterTuningJobRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() + + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
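+ # The expected header value mirrors what
+ # gapic_v1.routing_header.to_grpc_metadata((('parent', 'parent/value'),))
+ # would produce: ('x-goog-request-params', 'parent=parent/value').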
+ request = job_service.CreateHyperparameterTuningJobRequest()
+ request.parent = 'parent/value'
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job),
+ '__call__') as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob())
+
+ await client.create_hyperparameter_tuning_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'parent=parent/value',
+ ) in kw['metadata']
+
+
+def test_create_hyperparameter_tuning_job_flattened():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.create_hyperparameter_tuning_job(
+ parent='parent_value',
+ hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'),
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == 'parent_value'
+
+ assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value')
+
+
+def test_create_hyperparameter_tuning_job_flattened_error():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.create_hyperparameter_tuning_job(
+ job_service.CreateHyperparameterTuningJobRequest(),
+ parent='parent_value',
+ hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'),
+ )
+
+
+@pytest.mark.asyncio
+async def test_create_hyperparameter_tuning_job_flattened_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.create_hyperparameter_tuning_job(
+ parent='parent_value',
+ hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'),
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_flattened_error_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_hyperparameter_tuning_job( + job_service.CreateHyperparameterTuningJobRequest(), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + ) + + +def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + + display_name='display_name_value', + + max_trial_count=1609, + + parallel_trial_count=2128, + + max_failed_trial_count=2317, + + state=job_state.JobState.JOB_STATE_QUEUED, + + ) + + response = client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_hyperparameter_tuning_job_from_dict(): + test_get_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.GetHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + # Designate an appropriate return value for the call. 
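+ # grpc_helpers_async.FakeUnaryUnaryCall wraps the message so the mocked
+ # stub behaves like a grpc.aio unary-unary call: awaiting it yields the
+ # wrapped response.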
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) + + response = await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetHyperparameterTuningJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() + + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetHyperparameterTuningJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + + await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_hyperparameter_tuning_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + # Designate an appropriate return value for the call. 
+ call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.get_hyperparameter_tuning_job(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == 'name_value'
+
+
+def test_get_hyperparameter_tuning_job_flattened_error():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.get_hyperparameter_tuning_job(
+ job_service.GetHyperparameterTuningJobRequest(),
+ name='name_value',
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_hyperparameter_tuning_job_flattened_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_hyperparameter_tuning_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_hyperparameter_tuning_job(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_hyperparameter_tuning_job_flattened_error_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_hyperparameter_tuning_job(
+ job_service.GetHyperparameterTuningJobRequest(),
+ name='name_value',
+ )
+
+
+def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = job_service.ListHyperparameterTuningJobsResponse(
+ next_page_token='next_page_token_value',
+
+ )
+
+ response = client.list_hyperparameter_tuning_jobs(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.ListHyperparameterTuningJobsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_hyperparameter_tuning_jobs_from_dict(): + test_list_hyperparameter_tuning_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.ListHyperparameterTuningJobsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_hyperparameter_tuning_jobs_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListHyperparameterTuningJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: + call.return_value = job_service.ListHyperparameterTuningJobsResponse() + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListHyperparameterTuningJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'parent=parent/value',
+ ) in kw['metadata']
+
+
+def test_list_hyperparameter_tuning_jobs_flattened():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = job_service.ListHyperparameterTuningJobsResponse()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_hyperparameter_tuning_jobs(
+ parent='parent_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == 'parent_value'
+
+
+def test_list_hyperparameter_tuning_jobs_flattened_error():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_hyperparameter_tuning_jobs(
+ job_service.ListHyperparameterTuningJobsRequest(),
+ parent='parent_value',
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_flattened_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_hyperparameter_tuning_jobs(
+ parent='parent_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_flattened_error_async():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_hyperparameter_tuning_jobs(
+ job_service.ListHyperparameterTuningJobsRequest(),
+ parent='parent_value',
+ )
+
+
+def test_list_hyperparameter_tuning_jobs_pager():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__') as call:
+ # Set the response to a series of pages.
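+ # Each response below is served for one successive stub call; the
+ # trailing RuntimeError is a sentinel that would only surface if the
+ # pager kept fetching past the last page (empty next_page_token).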
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((
+ ('parent', ''),
+ )),
+ )
+ pager = client.list_hyperparameter_tuning_jobs(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+ for i in results)
+
+def test_list_hyperparameter_tuning_jobs_pages():
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__') as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages)
+ for page_, token in zip(pages, ['abc','def','ghi', '']):
+ assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_async_pager():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__', new_callable=mock.AsyncMock) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_hyperparameter_tuning_jobs(request={},)
+ assert async_pager.next_page_token == 'abc'
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+ for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_async_pages():
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ '__call__', new_callable=mock.AsyncMock) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='abc',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[],
+ next_page_token='def',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token='ghi',
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ['abc','def','ghi', '']):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_hyperparameter_tuning_job),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name='operations/spam')
+
+ response = client.delete_hyperparameter_tuning_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_hyperparameter_tuning_job_from_dict(): + test_delete_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.DeleteHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteHyperparameterTuningJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteHyperparameterTuningJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+def test_delete_hyperparameter_tuning_job_flattened():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_hyperparameter_tuning_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_hyperparameter_tuning_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_delete_hyperparameter_tuning_job_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_hyperparameter_tuning_job(
+            job_service.DeleteHyperparameterTuningJobRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_hyperparameter_tuning_job_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_hyperparameter_tuning_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_hyperparameter_tuning_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_hyperparameter_tuning_job_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_hyperparameter_tuning_job(
+            job_service.DeleteHyperparameterTuningJobRequest(),
+            name='name_value',
+        )
+
+
+def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_hyperparameter_tuning_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
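+        # Cancel RPCs return google.protobuf.Empty, which the generated
+        # client surfaces to the caller as ``None``.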
+ call.return_value = None + + response = client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_hyperparameter_tuning_job_from_dict(): + test_cancel_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelHyperparameterTuningJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = None + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelHyperparameterTuningJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+def test_cancel_hyperparameter_tuning_job_flattened():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_hyperparameter_tuning_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.cancel_hyperparameter_tuning_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_cancel_hyperparameter_tuning_job_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.cancel_hyperparameter_tuning_job(
+            job_service.CancelHyperparameterTuningJobRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_cancel_hyperparameter_tuning_job_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_hyperparameter_tuning_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.cancel_hyperparameter_tuning_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_cancel_hyperparameter_tuning_job_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.cancel_hyperparameter_tuning_job(
+            job_service.CancelHyperparameterTuningJobRequest(),
+            name='name_value',
+        )
+
+
+def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
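+        # A sparsely populated message is enough: proto-plus fills unset
+        # fields with their defaults, and the test only asserts on the
+        # scalar fields set here.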
+        call.return_value = gca_batch_prediction_job.BatchPredictionJob(
+            name='name_value',
+            display_name='display_name_value',
+            model='model_value',
+            generate_explanation=True,
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        )
+
+        response = client.create_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.CreateBatchPredictionJobRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob)
+
+    assert response.name == 'name_value'
+
+    assert response.display_name == 'display_name_value'
+
+    assert response.model == 'model_value'
+
+    assert response.generate_explanation is True
+
+    assert response.state == job_state.JobState.JOB_STATE_QUEUED
+
+
+def test_create_batch_prediction_job_from_dict():
+    test_create_batch_prediction_job(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio'):
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.CreateBatchPredictionJobRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob(
+            name='name_value',
+            display_name='display_name_value',
+            model='model_value',
+            generate_explanation=True,
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        ))
+
+        response = await client.create_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob)
+
+    assert response.name == 'name_value'
+
+    assert response.display_name == 'display_name_value'
+
+    assert response.model == 'model_value'
+
+    assert response.generate_explanation is True
+
+    assert response.state == job_state.JobState.JOB_STATE_QUEUED
+
+
+def test_create_batch_prediction_job_field_headers():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.CreateBatchPredictionJobRequest()
+    request.parent = 'parent/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        call.return_value = gca_batch_prediction_job.BatchPredictionJob()
+
+        client.create_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_create_batch_prediction_job_field_headers_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.CreateBatchPredictionJobRequest()
+    request.parent = 'parent/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob())
+
+        await client.create_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+def test_create_batch_prediction_job_flattened():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_batch_prediction_job.BatchPredictionJob()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.create_batch_prediction_job(
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value')
+
+
+def test_create_batch_prediction_job_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_batch_prediction_job(
+            job_service.CreateBatchPredictionJobRequest(),
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_batch_prediction_job_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_batch_prediction_job(
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value')
+
+
+@pytest.mark.asyncio
+async def test_create_batch_prediction_job_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_batch_prediction_job(
+            job_service.CreateBatchPredictionJobRequest(),
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
+        )
+
+
+def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = batch_prediction_job.BatchPredictionJob(
+            name='name_value',
+            display_name='display_name_value',
+            model='model_value',
+            generate_explanation=True,
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        )
+
+        response = client.get_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.GetBatchPredictionJobRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, batch_prediction_job.BatchPredictionJob)
+
+    assert response.name == 'name_value'
+
+    assert response.display_name == 'display_name_value'
+
+    assert response.model == 'model_value'
+
+    assert response.generate_explanation is True
+
+    assert response.state == job_state.JobState.JOB_STATE_QUEUED
+
+
+def test_get_batch_prediction_job_from_dict():
+    test_get_batch_prediction_job(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio'):
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.GetBatchPredictionJobRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
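+        # FakeUnaryUnaryCall wraps the message in an awaitable, standing in
+        # for the call object a real grpc.aio channel would return.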
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( + name='name_value', + display_name='display_name_value', + model='model_value', + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + )) + + response = await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, batch_prediction_job.BatchPredictionJob) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.model == 'model_value' + + assert response.generate_explanation is True + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_batch_prediction_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), + '__call__') as call: + call.return_value = batch_prediction_job.BatchPredictionJob() + + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + + await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_batch_prediction_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_batch_prediction_job( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_get_batch_prediction_job_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_batch_prediction_job(
+            job_service.GetBatchPredictionJobRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_batch_prediction_job_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_batch_prediction_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_batch_prediction_job_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_batch_prediction_job(
+            job_service.GetBatchPredictionJobRequest(),
+            name='name_value',
+        )
+
+
+def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse(
+            next_page_token='next_page_token_value',
+        )
+
+        response = client.list_batch_prediction_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.ListBatchPredictionJobsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListBatchPredictionJobsPager)
+
+    assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_batch_prediction_jobs_from_dict():
+    test_list_batch_prediction_jobs(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio'):
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+ request = job_service.ListBatchPredictionJobsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_batch_prediction_jobs_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListBatchPredictionJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: + call.return_value = job_service.ListBatchPredictionJobsResponse() + + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListBatchPredictionJobsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + + await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_batch_prediction_jobs_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListBatchPredictionJobsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
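+        # The flattened signature is sugar: the client copies each keyword
+        # into a fresh ListBatchPredictionJobsRequest before invoking the
+        # transport, which the assertions below verify.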
+        client.list_batch_prediction_jobs(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_batch_prediction_jobs_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_batch_prediction_jobs(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_batch_prediction_jobs_pager():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
+        # Set the response to a series of pages.
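+        # Giving ``side_effect`` an iterable makes the mock return one item
+        # per call, so four responses simulate four pages; the trailing
+        # RuntimeError would surface if the pager ever requested a fifth.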
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_batch_prediction_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, batch_prediction_job.BatchPredictionJob)
+                   for i in results)
+
+def test_list_batch_prediction_jobs_pages():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_batch_prediction_jobs(request={}).pages)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_pager():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
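+        # ``new_callable=mock.AsyncMock`` (above) makes the patched
+        # '__call__' awaitable, which the async pager needs for each
+        # page fetch.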
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_batch_prediction_jobs(request={})
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, batch_prediction_job.BatchPredictionJob)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_pages():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='abc',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[],
+                next_page_token='def',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_batch_prediction_jobs(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.delete_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.DeleteBatchPredictionJobRequest()
+
+    # Establish that the response is the type that we expect.
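+    # Delete is a long-running operation: the raw operations_pb2.Operation is
+    # wrapped in an api_core future, so a caller can block on completion,
+    # e.g. (sketch, with a hypothetical job name):
+    #   operation = client.delete_batch_prediction_job(name=job_name)
+    #   operation.result()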
+ assert isinstance(response, future.Future) + + +def test_delete_batch_prediction_job_from_dict(): + test_delete_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.DeleteBatchPredictionJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_batch_prediction_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_batch_prediction_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
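+    # Patching '__call__' on the type of the bound multicallable intercepts
+    # the RPC at the transport boundary, so no channel traffic occurs.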
+    with mock.patch.object(
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_batch_prediction_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_delete_batch_prediction_job_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_batch_prediction_job(
+            job_service.DeleteBatchPredictionJobRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_batch_prediction_job_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_batch_prediction_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_batch_prediction_job_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_batch_prediction_job(
+            job_service.DeleteBatchPredictionJobRequest(),
+            name='name_value',
+        )
+
+
+def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        response = client.cancel_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.CancelBatchPredictionJobRequest()
+
+    # Establish that the response is the type that we expect.
+ assert response is None + + +def test_cancel_batch_prediction_job_from_dict(): + test_cancel_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio'): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelBatchPredictionJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_batch_prediction_job_field_headers(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + call.return_value = None + + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_cancel_batch_prediction_job_flattened(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), + '__call__') as call: + # Designate an appropriate return value for the call. 
+        call.return_value = None
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.cancel_batch_prediction_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_cancel_batch_prediction_job_flattened_error():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.cancel_batch_prediction_job(
+            job_service.CancelBatchPredictionJobRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_cancel_batch_prediction_job_flattened_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.cancel_batch_prediction_job(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_cancel_batch_prediction_job_flattened_error_async():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.cancel_batch_prediction_job(
+            job_service.CancelBatchPredictionJobRequest(),
+            name='name_value',
+        )
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.JobServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = JobServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.JobServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = JobServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
+    transport = transports.JobServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = JobServiceClient(
+            client_options={"scopes": ["1", "2"]},
+            transport=transport,
+        )
+
+
+def test_transport_instance():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.JobServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    client = JobServiceClient(transport=transport)
+    assert client.transport is transport
+
+
+def test_transport_get_channel():
+    # A client may be instantiated with a custom transport instance.
+ transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.JobServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.JobServiceGrpcTransport, + transports.JobServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.JobServiceGrpcTransport, + ) + + +def test_job_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.JobServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_job_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.JobServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_custom_job', + 'get_custom_job', + 'list_custom_jobs', + 'delete_custom_job', + 'cancel_custom_job', + 'create_data_labeling_job', + 'get_data_labeling_job', + 'list_data_labeling_jobs', + 'delete_data_labeling_job', + 'cancel_data_labeling_job', + 'create_hyperparameter_tuning_job', + 'get_hyperparameter_tuning_job', + 'list_hyperparameter_tuning_jobs', + 'delete_hyperparameter_tuning_job', + 'cancel_hyperparameter_tuning_job', + 'create_batch_prediction_job', + 'get_batch_prediction_job', + 'list_batch_prediction_jobs', + 'delete_batch_prediction_job', + 'cancel_batch_prediction_job', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_job_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_job_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
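+    # google.auth.default() is the ADC entry point; patching it keeps the
+    # test hermetic (no GOOGLE_APPLICATION_CREDENTIALS, gcloud config, or
+    # metadata-server lookups).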
+ with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobServiceTransport() + adc.assert_called_once() + + +def test_job_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + JobServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_job_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_job_service_host_no_port(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_job_service_host_with_port(): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_job_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. + transport = transports.JobServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + + +def test_job_service_grpc_asyncio_transport_channel(): + channel = aio.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
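+    # With an explicit channel the transport adopts it as-is and skips
+    # credential resolution; the host is still normalized to carry the
+    # default :443 port, as asserted below.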
+    transport = transports.JobServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport])
+def test_job_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport])
+def test_job_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_job_service_grpc_lro_client():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_job_service_grpc_lro_async_client():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
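+    # (This operations client is what returned LRO futures rely on to poll
+    # and resolve long-running operations, e.g. when a caller invokes
+    # response.result() on the future returned by an LRO method.)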
+ assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + +def test_batch_prediction_job_path(): + project = "squid" + location = "clam" + batch_prediction_job = "whelk" + + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) + assert expected == actual + + +def test_parse_batch_prediction_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", + + } + path = JobServiceClient.batch_prediction_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_batch_prediction_job_path(path) + assert expected == actual + +def test_custom_job_path(): + project = "cuttlefish" + location = "mussel" + custom_job = "winkle" + + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + actual = JobServiceClient.custom_job_path(project, location, custom_job) + assert expected == actual + + +def test_parse_custom_job_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", + + } + path = JobServiceClient.custom_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_custom_job_path(path) + assert expected == actual + +def test_data_labeling_job_path(): + project = "squid" + location = "clam" + data_labeling_job = "whelk" + + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) + assert expected == actual + + +def test_parse_data_labeling_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", + + } + path = JobServiceClient.data_labeling_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_data_labeling_job_path(path) + assert expected == actual + +def test_dataset_path(): + project = "cuttlefish" + location = "mussel" + dataset = "winkle" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + actual = JobServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + + } + path = JobServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. 
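+    # For instance, the kwargs above should yield
+    # 'projects/nautilus/locations/scallop/datasets/abalone', and parsing that
+    # string should return the same dict.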
+ actual = JobServiceClient.parse_dataset_path(path) + assert expected == actual + +def test_hyperparameter_tuning_job_path(): + project = "squid" + location = "clam" + hyperparameter_tuning_job = "whelk" + + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) + assert expected == actual + + +def test_parse_hyperparameter_tuning_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", + + } + path = JobServiceClient.hyperparameter_tuning_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) + assert expected == actual + +def test_model_path(): + project = "cuttlefish" + location = "mussel" + model = "winkle" + + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + actual = JobServiceClient.model_path(project, location, model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "model": "abalone", + + } + path = JobServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = JobServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + + } + path = JobServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder, ) + actual = JobServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + + } + path = JobServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = JobServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + + } + path = JobServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project, ) + actual = JobServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + + } + path = JobServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = JobServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + + } + path = JobServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = JobServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py new file mode 100644 index 0000000000..865bcf4305 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -0,0 +1,1546 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient +from google.cloud.aiplatform_v1beta1.services.migration_service import pagers +from google.cloud.aiplatform_v1beta1.services.migration_service import transports +from google.cloud.aiplatform_v1beta1.types import migratable_resource +from google.cloud.aiplatform_v1beta1.types import migration_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert MigrationServiceClient._get_default_mtls_endpoint(None) is None + assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [MigrationServiceClient, MigrationServiceAsyncClient]) +def test_migration_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_migration_service_client_get_transport_class(): + transport = MigrationServiceClient.get_transport_class() + assert transport == transports.MigrationServiceGrpcTransport + + transport = MigrationServiceClient.get_transport_class("grpc") + assert transport == transports.MigrationServiceGrpcTransport + + 
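+# As a usage sketch, the same lookup runs when a transport name is passed to
+# the client constructor, e.g. MigrationServiceClient(transport="grpc"), which
+# is what the parametrized cases below exercise.
+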
+@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +def test_migration_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_migration_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = MigrationServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = migration_service.SearchMigratableResourcesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.SearchMigratableResourcesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_search_migratable_resources_from_dict(): + test_search_migratable_resources(request_type=dict) + + +@pytest.mark.asyncio +async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio'): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = migration_service.SearchMigratableResourcesRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_search_migratable_resources_field_headers(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.SearchMigratableResourcesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), + '__call__') as call: + call.return_value = migration_service.SearchMigratableResourcesResponse() + + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_search_migratable_resources_field_headers_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.SearchMigratableResourcesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + + await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_search_migratable_resources_flattened(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = migration_service.SearchMigratableResourcesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.search_migratable_resources( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_search_migratable_resources_flattened_error(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
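+    # (Callers pick one calling shape or the other: either a fully-formed
+    # request proto, or keyword arguments for the flattened fields, never both.)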
+    with pytest.raises(ValueError):
+        client.search_migratable_resources(
+            migration_service.SearchMigratableResourcesRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_flattened_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.search_migratable_resources),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.search_migratable_resources(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_flattened_error_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.search_migratable_resources(
+            migration_service.SearchMigratableResourcesRequest(),
+            parent='parent_value',
+        )
+
+
+def test_search_migratable_resources_pager():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.search_migratable_resources),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token='abc',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[],
+                next_page_token='def',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token='ghi',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.search_migratable_resources(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, migratable_resource.MigratableResource)
+                   for i in results)
+
+def test_search_migratable_resources_pages():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.search_migratable_resources),
+            '__call__') as call:
+        # Set the response to a series of pages.
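+        # (Each successive stub call returns the next page; the trailing
+        # RuntimeError makes the test fail loudly if the pager ever fetches
+        # past the final page, whose next_page_token is empty.)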
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token='abc',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[],
+                next_page_token='def',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token='ghi',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.search_migratable_resources(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_async_pager():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.search_migratable_resources),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token='abc',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[],
+                next_page_token='def',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token='ghi',
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.search_migratable_resources(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, migratable_resource.MigratableResource)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_async_pages():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.search_migratable_resources),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + migration_service.SearchMigratableResourcesResponse( + migratable_resources=[ + migratable_resource.MigratableResource(), + migratable_resource.MigratableResource(), + migratable_resource.MigratableResource(), + ], + next_page_token='abc', + ), + migration_service.SearchMigratableResourcesResponse( + migratable_resources=[], + next_page_token='def', + ), + migration_service.SearchMigratableResourcesResponse( + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', + ), + migration_service.SearchMigratableResourcesResponse( + migratable_resources=[ + migratable_resource.MigratableResource(), + migratable_resource.MigratableResource(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.search_migratable_resources(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.BatchMigrateResourcesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_batch_migrate_resources_from_dict(): + test_batch_migrate_resources(request_type=dict) + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio'): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = migration_service.BatchMigrateResourcesRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_batch_migrate_resources_field_headers(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
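+    # (The x-goog-request-params metadata entry asserted below carries the
+    # URI-bound fields so that, presumably, the backend can route the request
+    # correctly.)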
+ request = migration_service.BatchMigrateResourcesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_field_headers_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.BatchMigrateResourcesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_batch_migrate_resources_flattened(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_migrate_resources( + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + + +def test_batch_migrate_resources_flattened_error(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.batch_migrate_resources(
+            migration_service.BatchMigrateResourcesRequest(),
+            parent='parent_value',
+            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+        )
+
+
+@pytest.mark.asyncio
+async def test_batch_migrate_resources_flattened_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.batch_migrate_resources),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.batch_migrate_resources(
+            parent='parent_value',
+            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))]
+
+
+@pytest.mark.asyncio
+async def test_batch_migrate_resources_flattened_error_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.batch_migrate_resources(
+            migration_service.BatchMigrateResourcesRequest(),
+            parent='parent_value',
+            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+        )
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.MigrationServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = MigrationServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.MigrationServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = MigrationServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
+ transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = MigrationServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.MigrationServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.MigrationServiceGrpcTransport, + ) + + +def test_migration_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.MigrationServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_migration_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.MigrationServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
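+    # (The base transport acts as an abstract interface here; the concrete
+    # grpc and grpc_asyncio transports are the ones expected to override
+    # each of these methods.)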
+ methods = ( + 'search_migratable_resources', + 'batch_migrate_resources', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_migration_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MigrationServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_migration_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MigrationServiceTransport() + adc.assert_called_once() + + +def test_migration_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + MigrationServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_migration_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_migration_service_host_no_port(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_migration_service_host_with_port(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_migration_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
+    transport = transports.MigrationServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+def test_migration_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel('http://localhost/')
+
+    # Check that channel is used if provided.
+    transport = transports.MigrationServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport])
+def test_migration_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport])
+def test_migration_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_migration_service_grpc_lro_client():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_migration_service_grpc_lro_async_client():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsAsyncClient,
+    )
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+def test_annotated_dataset_path():
+    project = "squid"
+    dataset = "clam"
+    annotated_dataset = "whelk"
+
+    expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, )
+    actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset)
+    assert expected == actual
+
+
+def test_parse_annotated_dataset_path():
+    expected = {
+        "project": "octopus",
+        "dataset": "oyster",
+        "annotated_dataset": "nudibranch",
+    }
+    path = MigrationServiceClient.annotated_dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_annotated_dataset_path(path)
+    assert expected == actual
+
+def test_dataset_path():
+    project = "cuttlefish"
+    dataset = "mussel"
+
+    expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, )
+    actual = MigrationServiceClient.dataset_path(project, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "winkle",
+        "dataset": "nautilus",
+    }
+    path = MigrationServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+def test_model_path():
+    project = "oyster"
+    location = "nudibranch"
+    model = "cuttlefish"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    actual = MigrationServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "mussel",
+        "location": "winkle",
+        "model": "nautilus",
+    }
+    path = MigrationServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_model_path(path)
+    assert expected == actual
+
+def test_version_path():
+    project = "scallop"
+    model = "abalone"
+    version = "squid"
+
+    expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, )
+    actual = MigrationServiceClient.version_path(project, model, version)
+    assert expected == actual
+
+
+def test_parse_version_path():
+    expected = {
+        "project": "clam",
+        "model": "whelk",
+        "version": "octopus",
+    }
+    path = MigrationServiceClient.version_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_version_path(path)
+    assert expected == actual
+
+def test_common_billing_account_path():
+    billing_account = "oyster"
+
+    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    actual = MigrationServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "nudibranch",
+    }
+    path = MigrationServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+def test_common_folder_path():
+    folder = "cuttlefish"
+
+    expected = "folders/{folder}".format(folder=folder, )
+    actual = MigrationServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "mussel",
+    }
+    path = MigrationServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+def test_common_organization_path():
+    organization = "winkle"
+
+    expected = "organizations/{organization}".format(organization=organization, )
+    actual = MigrationServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "nautilus",
+    }
+    path = MigrationServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_common_organization_path(path)
+    assert expected == actual
+
+def test_common_project_path():
+    project = "scallop"
+
+    expected = "projects/{project}".format(project=project, )
+    actual = MigrationServiceClient.common_project_path(project)
+    assert expected == actual
+
+
+def test_parse_common_project_path():
+    expected = {
+        "project": "abalone",
+    }
+    path = MigrationServiceClient.common_project_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_common_project_path(path)
+    assert expected == actual
+
+def test_common_location_path():
+    project = "squid"
+    location = "clam"
+
+    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
+    actual = MigrationServiceClient.common_location_path(project, location)
+    assert expected == actual
+
+
+def test_parse_common_location_path():
+    expected = {
+        "project": "whelk",
+        "location": "octopus",
+    }
+    path = MigrationServiceClient.common_location_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_with_default_client_info():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep:
+        client = MigrationServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep:
+        transport_class = MigrationServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
new file mode 100644
index 0000000000..bd93f3c4ee
--- /dev/null
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
@@ -0,0 +1,3799 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient +from google.cloud.aiplatform_v1beta1.services.model_service import pagers +from google.cloud.aiplatform_v1beta1.services.model_service import transports +from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import env_var +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform_v1beta1.types import model_evaluation +from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice +from google.cloud.aiplatform_v1beta1.types import model_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
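+# Illustrative derivation (mirrors test__get_default_mtls_endpoint below):
+#
+#     "example.googleapis.com"         -> "example.mtls.googleapis.com"
+#     "example.sandbox.googleapis.com" -> "example.mtls.sandbox.googleapis.com"
+#     "api.example.com"                -> "api.example.com"  (non-Google hosts pass through)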
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelServiceClient._get_default_mtls_endpoint(None) is None + assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +def test_model_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_model_service_client_get_transport_class(): + transport = ModelServiceClient.get_transport_class() + assert transport == transports.ModelServiceGrpcTransport + + transport = ModelServiceClient.get_transport_class("grpc") + assert transport == transports.ModelServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
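+    # Illustrative usage (not part of the assertions): an end user reaches this
+    # code path with an explicit endpoint override, e.g.
+    #
+    #     options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
+    #     client = ModelServiceClient(client_options=options)
+    #
+    # in which case the transport is built against that host verbatim and no
+    # mTLS autoswitching takes place.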
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
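+    # Decision matrix exercised by this test (GOOGLE_API_USE_MTLS_ENDPOINT=auto,
+    # as patched on the test above):
+    #
+    #     client cert available | GOOGLE_API_USE_CLIENT_CERTIFICATE | endpoint used
+    #     ----------------------+-----------------------------------+----------------------
+    #     yes                   | "true"                            | DEFAULT_MTLS_ENDPOINT
+    #     yes                   | "false"                           | DEFAULT_ENDPOINT
+    #     no                    | "true" or "false"                 | DEFAULT_ENDPOINT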
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_model_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = ModelServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
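+    # The patch below replaces '__call__' on the type of the transport's cached
+    # upload_model callable, so the RPC never leaves the process; the test only
+    # exercises request construction and response handling. Roughly (illustrative):
+    #
+    #     client.upload_model(request)
+    #         -> transport.upload_model(request, metadata=...)   # the mocked call
+    #         -> operations_pb2.Operation(name='operations/spam')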
+ with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_upload_model_from_dict(): + test_upload_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_upload_model_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.UploadModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_upload_model_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UploadModelRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_upload_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UploadModelRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
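+    # The routing header travels as gRPC metadata, i.e. the call's metadata kwarg
+    # contains the tuple ('x-goog-request-params', 'parent=parent/value').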
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+def test_upload_model_flattened():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.upload_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.upload_model(
+            parent='parent_value',
+            model=gca_model.Model(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].model == gca_model.Model(name='name_value')
+
+
+def test_upload_model_flattened_error():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.upload_model(
+            model_service.UploadModelRequest(),
+            parent='parent_value',
+            model=gca_model.Model(name='name_value'),
+        )
+
+
+@pytest.mark.asyncio
+async def test_upload_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.upload_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.upload_model(
+            parent='parent_value',
+            model=gca_model.Model(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].model == gca_model.Model(name='name_value')
+
+
+@pytest.mark.asyncio
+async def test_upload_model_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.upload_model(
+            model_service.UploadModelRequest(),
+            parent='parent_value',
+            model=gca_model.Model(name='name_value'),
+        )
+
+
+def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = model.Model( + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + + ) + + response = client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.training_pipeline == 'training_pipeline_value' + + assert response.artifact_uri == 'artifact_uri_value' + + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + + assert response.etag == 'etag_value' + + +def test_get_model_from_dict(): + test_get_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.GetModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) + + response = await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, model.Model) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.training_pipeline == 'training_pipeline_value' + + assert response.artifact_uri == 'artifact_uri_value' + + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + + assert response.etag == 'etag_value' + + +def test_get_model_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + call.return_value = model.Model() + + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_model_flattened(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_model_flattened_error(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
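+    # Illustrative contrast between the two mutually exclusive calling styles:
+    #
+    #     client.get_model(name='name_value')                              # flattened: OK
+    #     client.get_model(model_service.GetModelRequest(name='n'))        # request object: OK
+    #     client.get_model(model_service.GetModelRequest(), name='n')      # both: ValueError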
+    with pytest.raises(ValueError):
+        client.get_model(
+            model_service.GetModelRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_model_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model(
+            model_service.GetModelRequest(),
+            name='name_value',
+        )
+
+
+def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelsResponse(
+            next_page_token='next_page_token_value',
+
+        )
+
+        response = client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelsPager)
+
+    assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_models_from_dict():
+    test_list_models(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_models_async(transport: str = 'grpc_asyncio'):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = model_service.ListModelsRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse(
+            next_page_token='next_page_token_value',
+        ))
+
+        response = await client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_models_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: + call.return_value = model_service.ListModelsResponse() + + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_models_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + + await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_models_flattened(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_models( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + +def test_list_models_flattened_error(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_models( + model_service.ListModelsRequest(), + parent='parent_value', + ) + + +@pytest.mark.asyncio +async def test_list_models_flattened_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_models(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_models(
+            model_service.ListModelsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_models_pager():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_models(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model.Model)
+                   for i in results)
+
+def test_list_models_pages():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_models(request={}).pages)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_models_async_pager():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
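+        # side_effect pops one response per page fetch; the trailing RuntimeError
+        # guards against the pager requesting more pages than were staged. The
+        # async pager then flattens 3 + 0 + 1 + 2 == 6 Model items, roughly
+        # (illustrative):
+        #
+        #     async_pager = await client.list_models(request={})
+        #     models = [m async for m in async_pager]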
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_models(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model.Model)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_models_async_pages():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_models(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_model.Model(
+            name='name_value',
+
+            display_name='display_name_value',
+
+            description='description_value',
+
+            metadata_schema_uri='metadata_schema_uri_value',
+
+            training_pipeline='training_pipeline_value',
+
+            artifact_uri='artifact_uri_value',
+
+            supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES],
+
+            supported_input_storage_formats=['supported_input_storage_formats_value'],
+
+            supported_output_storage_formats=['supported_output_storage_formats_value'],
+
+            etag='etag_value',
+
+        )
+
+        response = client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.UpdateModelRequest()
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, gca_model.Model) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.training_pipeline == 'training_pipeline_value' + + assert response.artifact_uri == 'artifact_uri_value' + + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + + assert response.etag == 'etag_value' + + +def test_update_model_from_dict(): + test_update_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_model_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.UpdateModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) + + response = await client.update_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_model.Model) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' + + assert response.metadata_schema_uri == 'metadata_schema_uri_value' + + assert response.training_pipeline == 'training_pipeline_value' + + assert response.artifact_uri == 'artifact_uri_value' + + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + + assert response.etag == 'etag_value' + + +def test_update_model_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UpdateModelRequest() + request.model.name = 'model.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.update_model),
+            '__call__') as call:
+        call.return_value = gca_model.Model()
+
+        client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model.name=model.name/value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_update_model_field_headers_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.UpdateModelRequest()
+    request.model.name = 'model.name/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_model),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())
+
+        await client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model.name=model.name/value',
+    ) in kw['metadata']
+
+
+def test_update_model_flattened():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_model.Model()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_model(
+            model=gca_model.Model(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].model == gca_model.Model(name='name_value')
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value'])
+
+
+def test_update_model_flattened_error():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_model(
+            model_service.UpdateModelRequest(),
+            model=gca_model.Model(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+
+@pytest.mark.asyncio
+async def test_update_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+ response = await client.update_model( + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].model == gca_model.Model(name='name_value') + + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + + +@pytest.mark.asyncio +async def test_update_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_model( + model_service.UpdateModelRequest(), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.DeleteModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_model_from_dict(): + test_delete_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_model_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.DeleteModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_model_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.DeleteModelRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
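+    # delete_model is a long-running operation, but a bare
+    # operations_pb2.Operation return value suffices here, since only the
+    # outgoing request metadata is under test.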
+    with mock.patch.object(
+            type(client.transport.delete_model),
+            '__call__') as call:
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_delete_model_field_headers_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.DeleteModelRequest()
+    request.name = 'name/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_model),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+
+        await client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+def test_delete_model_flattened():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_model(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_delete_model_flattened_error():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_model(
+            model_service.DeleteModelRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_model(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +@pytest.mark.asyncio +async def test_delete_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_model( + model_service.DeleteModelRequest(), + name='name_value', + ) + + +def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.export_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ExportModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_model_from_dict(): + test_export_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_export_model_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.ExportModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.export_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_model_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ExportModelRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.export_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
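+    # Each mock_calls entry unpacks as (name, args, kwargs); the routing
+    # header travels in the 'metadata' keyword argument.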
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_export_model_field_headers_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ExportModelRequest()
+    request.name = 'name/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.export_model),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+
+        await client.export_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+def test_export_model_flattened():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.export_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.export_model(
+            name='name_value',
+            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value')
+
+
+def test_export_model_flattened_error():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.export_model(
+            model_service.ExportModelRequest(),
+            name='name_value',
+            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.export_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.export_model(
+            name='name_value',
+            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + + +@pytest.mark.asyncio +async def test_export_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.export_model( + model_service.ExportModelRequest(), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + ) + + +def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation.ModelEvaluation( + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + + slice_dimensions=['slice_dimensions_value'], + + ) + + response = client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model_evaluation.ModelEvaluation) + + assert response.name == 'name_value' + + assert response.metrics_schema_uri == 'metrics_schema_uri_value' + + assert response.slice_dimensions == ['slice_dimensions_value'] + + +def test_get_model_evaluation_from_dict(): + test_get_model_evaluation(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.GetModelEvaluationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation( + name='name_value', + metrics_schema_uri='metrics_schema_uri_value', + slice_dimensions=['slice_dimensions_value'], + )) + + response = await client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
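+        # Awaiting the call unwraps the FakeUnaryUnaryCall, so the result
+        # is the plain response message rather than the mock wrapper.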
+ assert isinstance(response, model_evaluation.ModelEvaluation) + + assert response.name == 'name_value' + + assert response.metrics_schema_uri == 'metrics_schema_uri_value' + + assert response.slice_dimensions == ['slice_dimensions_value'] + + +def test_get_model_evaluation_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), + '__call__') as call: + call.return_value = model_evaluation.ModelEvaluation() + + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_model_evaluation_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + + await client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_model_evaluation_flattened(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation.ModelEvaluation() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model_evaluation( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_model_evaluation_flattened_error(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
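+    # Mixing the two calling conventions would be ambiguous, so the client
+    # raises immediately instead of guessing which value should win.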
+    with pytest.raises(ValueError):
+        client.get_model_evaluation(
+            model_service.GetModelEvaluationRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_model_evaluation),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model_evaluation(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model_evaluation(
+            model_service.GetModelEvaluationRequest(),
+            name='name_value',
+        )
+
+
+def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationsResponse(
+            next_page_token='next_page_token_value',
+
+        )
+
+        response = client.list_model_evaluations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationsRequest()
+
+        # Establish that the response is the type that we expect.
+        assert isinstance(response, pagers.ListModelEvaluationsPager)
+
+        assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_model_evaluations_from_dict():
+    test_list_model_evaluations(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio'):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = model_service.ListModelEvaluationsRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
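+        # FakeUnaryUnaryCall (a test helper from google.api_core's
+        # grpc_helpers_async module) makes the mocked response awaitable,
+        # mimicking a real unary-unary gRPC call.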
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_model_evaluations_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), + '__call__') as call: + call.return_value = model_service.ListModelEvaluationsResponse() + + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_model_evaluations_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + + await client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_model_evaluations_flattened(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_model_evaluations( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_model_evaluations_flattened_error():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_model_evaluations(
+            model_service.ListModelEvaluationsRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_model_evaluations(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_model_evaluations(
+            model_service.ListModelEvaluationsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_model_evaluations_pager():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_model_evaluations(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model_evaluation.ModelEvaluation)
+                   for i in results)
+
+def test_list_model_evaluations_pages():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
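+    # Unlike the pager test above, this variant iterates page objects via
+    # .pages -- the same surface a caller would use to read page-level
+    # fields, e.g. (illustrative only):
+    #     for page in client.list_model_evaluations(request={}).pages:
+    #         print(page.raw_page.next_page_token)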
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_model_evaluations(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_pager():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_model_evaluations(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model_evaluation.ModelEvaluation)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_pages():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluations),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
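+        # Assigning a sequence to side_effect yields one response per
+        # invocation; the trailing RuntimeError would surface if the pager
+        # ever requested more pages than the test supplies.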
+ call.side_effect = ( + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + next_page_token='abc', + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[], + next_page_token='def', + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + ], + next_page_token='ghi', + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_model_evaluations(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation_slice.ModelEvaluationSlice( + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + + ) + + response = client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationSliceRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) + + assert response.name == 'name_value' + + assert response.metrics_schema_uri == 'metrics_schema_uri_value' + + +def test_get_model_evaluation_slice_from_dict(): + test_get_model_evaluation_slice(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio'): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.GetModelEvaluationSliceRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( + name='name_value', + metrics_schema_uri='metrics_schema_uri_value', + )) + + response = await client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) + + assert response.name == 'name_value' + + assert response.metrics_schema_uri == 'metrics_schema_uri_value' + + +def test_get_model_evaluation_slice_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationSliceRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + call.return_value = model_evaluation_slice.ModelEvaluationSlice() + + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationSliceRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + + await client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_get_model_evaluation_slice_flattened(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation_slice.ModelEvaluationSlice() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model_evaluation_slice( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == 'name_value' + + +def test_get_model_evaluation_slice_flattened_error(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.get_model_evaluation_slice(
+            model_service.GetModelEvaluationSliceRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model_evaluation_slice(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model_evaluation_slice(
+            model_service.GetModelEvaluationSliceRequest(),
+            name='name_value',
+        )
+
+
+def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationSlicesResponse(
+            next_page_token='next_page_token_value',
+
+        )
+
+        response = client.list_model_evaluation_slices(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationSlicesRequest()
+
+        # Establish that the response is the type that we expect.
+        assert isinstance(response, pagers.ListModelEvaluationSlicesPager)
+
+        assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_model_evaluation_slices_from_dict():
+    test_list_model_evaluation_slices(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio'):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = model_service.ListModelEvaluationSlicesRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object( + type(client.transport.list_model_evaluation_slices), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_model_evaluation_slices_field_headers(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationSlicesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), + '__call__') as call: + call.return_value = model_service.ListModelEvaluationSlicesResponse() + + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_model_evaluation_slices_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationSlicesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + + await client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_list_model_evaluation_slices_flattened(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationSlicesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_model_evaluation_slices( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_model_evaluation_slices_flattened_error():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_model_evaluation_slices(
+            model_service.ListModelEvaluationSlicesRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_model_evaluation_slices(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_model_evaluation_slices(
+            model_service.ListModelEvaluationSlicesRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_model_evaluation_slices_pager():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_model_evaluation_slices(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+                   for i in results)
+
+def test_list_model_evaluation_slices_pages():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_model_evaluation_slices(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_pager():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluation_slices),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
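+        # The async pager supports element-wise iteration, mirroring what a
+        # caller would write (illustrative only):
+        #     async for slice_ in await client.list_model_evaluation_slices(request):
+        #         ...
+        # Page-wise iteration via .pages is exercised in the
+        # *_async_pages test below.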
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_model_evaluation_slices(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_pages():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_model_evaluation_slices),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_model_evaluation_slices(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
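+    # A fully-constructed transport already carries its own credentials and
+    # scopes, so client_options that would override them are rejected.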
+ transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = ModelServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ModelServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ModelServiceGrpcTransport, + ) + + +def test_model_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_model_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
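+    # The base transport only declares the interface; each concrete
+    # transport (gRPC, gRPC asyncio) is responsible for wiring up the RPCs
+    # named below.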
+ methods = ( + 'upload_model', + 'get_model', + 'list_models', + 'update_model', + 'delete_model', + 'export_model', + 'get_model_evaluation', + 'list_model_evaluations', + 'get_model_evaluation_slice', + 'list_model_evaluation_slices', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_model_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_model_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport() + adc.assert_called_once() + + +def test_model_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + ModelServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_model_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_model_service_host_no_port(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_model_service_host_with_port(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_model_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
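+    # (A caller-supplied channel should be adopted as-is; the transport must
+    # not create a new channel of its own.)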
+    transport = transports.ModelServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+def test_model_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel('http://localhost/')
+
+    # Check that channel is used if provided.
+    transport = transports.ModelServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
+def test_model_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
+def test_model_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_model_service_grpc_lro_client():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
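+    # (the operations client is built once and cached on the transport.)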
+    assert transport.operations_client is transport.operations_client
+
+
+def test_model_service_grpc_lro_async_client():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsAsyncClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, )
+    actual = ModelServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+
+    }
+    path = ModelServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    actual = ModelServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+
+    }
+    path = ModelServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_model_path(path)
+    assert expected == actual
+
+def test_model_evaluation_path():
+    project = "squid"
+    location = "clam"
+    model = "whelk"
+    evaluation = "octopus"
+
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, )
+    actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation)
+    assert expected == actual
+
+
+def test_parse_model_evaluation_path():
+    expected = {
+        "project": "oyster",
+        "location": "nudibranch",
+        "model": "cuttlefish",
+        "evaluation": "mussel",
+
+    }
+    path = ModelServiceClient.model_evaluation_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_model_evaluation_path(path)
+    assert expected == actual
+
+def test_model_evaluation_slice_path():
+    project = "winkle"
+    location = "nautilus"
+    model = "scallop"
+    evaluation = "abalone"
+    slice = "squid"
+
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, )
+    actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice)
+    assert expected == actual
+
+
+def test_parse_model_evaluation_slice_path():
+    expected = {
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
+        "evaluation": "oyster",
+        "slice": "nudibranch",
+
+    }
+    path = ModelServiceClient.model_evaluation_slice_path(**expected)
+
+    # Check that the path construction is reversible.
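+    # (each parse_* helper is expected to be the exact inverse of the
+    # corresponding *_path builder.)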
+ actual = ModelServiceClient.parse_model_evaluation_slice_path(path) + assert expected == actual + +def test_training_pipeline_path(): + project = "cuttlefish" + location = "mussel" + training_pipeline = "winkle" + + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline) + assert expected == actual + + +def test_parse_training_pipeline_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", + + } + path = ModelServiceClient.training_pipeline_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_training_pipeline_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = ModelServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + + } + path = ModelServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder, ) + actual = ModelServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + + } + path = ModelServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = ModelServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + + } + path = ModelServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project, ) + actual = ModelServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + + } + path = ModelServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = ModelServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + + } + path = ModelServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+    actual = ModelServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_with_default_client_info():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+        client = ModelServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+        transport_class = ModelServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
new file mode 100644
index 0000000000..82f1b4f546
--- /dev/null
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
@@ -0,0 +1,2160 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import mock
+
+import grpc
+from grpc.experimental import aio
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+
+from google import auth
+from google.api_core import client_options
+from google.api_core import exceptions
+from google.api_core import future
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import operation_async  # type: ignore
+from google.api_core import operations_v1
+from google.auth import credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceAsyncClient
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceClient
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports
+from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
+from google.cloud.aiplatform_v1beta1.types import env_var
+from google.cloud.aiplatform_v1beta1.types import explanation
+from google.cloud.aiplatform_v1beta1.types import explanation_metadata
+from google.cloud.aiplatform_v1beta1.types import io
+from google.cloud.aiplatform_v1beta1.types import model
+from google.cloud.aiplatform_v1beta1.types import operation as gca_operation
+from google.cloud.aiplatform_v1beta1.types import pipeline_service
+from google.cloud.aiplatform_v1beta1.types import pipeline_state
+from google.cloud.aiplatform_v1beta1.types import training_pipeline
+from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline
+from google.longrunning import operations_pb2
+from google.oauth2 import service_account
+from google.protobuf import any_pb2 as gp_any  # type: ignore
+from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PipelineServiceClient._get_default_mtls_endpoint(None) is None + assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [PipelineServiceClient, PipelineServiceAsyncClient]) +def test_pipeline_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_pipeline_service_client_get_transport_class(): + transport = PipelineServiceClient.get_transport_class() + assert transport == transports.PipelineServiceGrpcTransport + + transport = PipelineServiceClient.get_transport_class("grpc") + assert transport == transports.PipelineServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +def test_pipeline_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. 
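+    # (a transport name such as "grpc" is resolved through get_transport_class,
+    # which is why the mock below is expected to be called.)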
+ with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
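+    # (only the strings "true" and "false" are accepted for this variable.)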
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_pipeline_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = PipelineServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline( + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + + ) + + response = client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.training_task_definition == 'training_task_definition_value' + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_create_training_pipeline_from_dict(): + test_create_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio'): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.CreateTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. 
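+        # (the async stub must return an awaitable, so the response message is
+        # wrapped in a FakeUnaryUnaryCall rather than returned directly.)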
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) + + response = await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.training_task_definition == 'training_task_definition_value' + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_create_training_pipeline_field_headers(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreateTrainingPipelineRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), + '__call__') as call: + call.return_value = gca_training_pipeline.TrainingPipeline() + + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreateTrainingPipelineRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + + await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_training_pipeline_flattened(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
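+        # (the client folds these keyword arguments into a
+        # CreateTrainingPipelineRequest on the caller's behalf.)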
+        client.create_training_pipeline(
+            parent='parent_value',
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value')
+
+
+def test_create_training_pipeline_flattened_error():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_training_pipeline(
+            pipeline_service.CreateTrainingPipelineRequest(),
+            parent='parent_value',
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_training_pipeline(
+            parent='parent_value',
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value')
+
+
+@pytest.mark.asyncio
+async def test_create_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_training_pipeline(
+            pipeline_service.CreateTrainingPipelineRequest(),
+            parent='parent_value',
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+        )
+
+
+def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest):
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = training_pipeline.TrainingPipeline( + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + + ) + + response = client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, training_pipeline.TrainingPipeline) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.training_task_definition == 'training_task_definition_value' + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_get_training_pipeline_from_dict(): + test_get_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio'): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.GetTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) + + response = await client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, training_pipeline.TrainingPipeline) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.training_task_definition == 'training_task_definition_value' + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_get_training_pipeline_field_headers(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetTrainingPipelineRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), + '__call__') as call: + call.return_value = training_pipeline.TrainingPipeline() + + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
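+    # (the x-goog-request-params header carries the resource name so the
+    # backend can route the request without parsing the payload.)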
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_get_training_pipeline_field_headers_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = pipeline_service.GetTrainingPipelineRequest()
+    request.name = 'name/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_training_pipeline),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline())
+
+        await client.get_training_pipeline(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+def test_get_training_pipeline_flattened():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = training_pipeline.TrainingPipeline()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_training_pipeline(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_get_training_pipeline_flattened_error():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_training_pipeline(
+            pipeline_service.GetTrainingPipelineRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_training_pipeline(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.get_training_pipeline( + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', + ) + + +def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrainingPipelinesPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_training_pipelines_from_dict(): + test_list_training_pipelines(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio'): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.ListTrainingPipelinesRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_training_pipelines_field_headers(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.ListTrainingPipelinesRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), + '__call__') as call: + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_field_headers_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = pipeline_service.ListTrainingPipelinesRequest()
+    request.parent = 'parent/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse())
+
+        await client.list_training_pipelines(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+def test_list_training_pipelines_flattened():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = pipeline_service.ListTrainingPipelinesResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_training_pipelines(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_training_pipelines_flattened_error():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_training_pipelines(
+            pipeline_service.ListTrainingPipelinesRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_flattened_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_training_pipelines(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_flattened_error_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_training_pipelines(
+            pipeline_service.ListTrainingPipelinesRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_training_pipelines_pager():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='abc',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[],
+                next_page_token='def',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_training_pipelines(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, training_pipeline.TrainingPipeline)
+                   for i in results)
+
+def test_list_training_pipelines_pages():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='abc',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[],
+                next_page_token='def',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_training_pipelines(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_async_pager():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
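+        # (side_effect returns one response per call; the trailing RuntimeError
+        # fails the test if the pager requests more pages than were supplied.)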
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='abc',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[],
+                next_page_token='def',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_training_pipelines(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, training_pipeline.TrainingPipeline)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_async_pages():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_training_pipelines),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='abc',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[],
+                next_page_token='def',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_training_pipelines(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest):
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.delete_training_pipeline(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == pipeline_service.DeleteTrainingPipelineRequest()
+
+        # Establish that the response is the type that we expect.
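+        # (delete is a long-running operation; the client wraps the raw
+        # Operation in a future the caller can poll or await.)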
+ assert isinstance(response, future.Future) + + +def test_delete_training_pipeline_from_dict(): + test_delete_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio'): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.DeleteTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_training_pipeline_field_headers(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeleteTrainingPipelineRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeleteTrainingPipelineRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_training_pipeline_flattened(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
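+    # The flattened "name" keyword should be folded into the request message
+    # that ultimately reaches the transport.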
+    with mock.patch.object(
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_training_pipeline(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_delete_training_pipeline_flattened_error():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_training_pipeline(
+            pipeline_service.DeleteTrainingPipelineRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_training_pipeline(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_training_pipeline(
+            pipeline_service.DeleteTrainingPipelineRequest(),
+            name='name_value',
+        )
+
+
+def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest):
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        response = client.cancel_training_pipeline(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == pipeline_service.CancelTrainingPipelineRequest()
+
+    # Establish that the response is the type that we expect.
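+    # CancelTrainingPipeline returns google.protobuf.Empty, which the
+    # generated surface exposes as None.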
+ assert response is None + + +def test_cancel_training_pipeline_from_dict(): + test_cancel_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio'): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.CancelTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_training_pipeline_field_headers(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelTrainingPipelineRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), + '__call__') as call: + call.return_value = None + + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelTrainingPipelineRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_cancel_training_pipeline_flattened(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), + '__call__') as call: + # Designate an appropriate return value for the call. 
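+        # Cancel carries no payload, so None stands in for Empty here.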
+        call.return_value = None
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.cancel_training_pipeline(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_cancel_training_pipeline_flattened_error():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.cancel_training_pipeline(
+            pipeline_service.CancelTrainingPipelineRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_cancel_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.cancel_training_pipeline(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_cancel_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.cancel_training_pipeline(
+            pipeline_service.CancelTrainingPipelineRequest(),
+            name='name_value',
+        )
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = PipelineServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = PipelineServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = PipelineServiceClient(
+            client_options={"scopes": ["1", "2"]},
+            transport=transport,
+        )
+
+
+def test_transport_instance():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    client = PipelineServiceClient(transport=transport)
+    assert client.transport is transport
+
+
+def test_transport_get_channel():
+    # A client may be instantiated with a custom transport instance.
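+    # Both the sync and asyncio gRPC transports should expose the underlying
+    # channel via their grpc_channel property.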
+ transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PipelineServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.PipelineServiceGrpcTransport, + ) + + +def test_pipeline_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_pipeline_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_training_pipeline', + 'get_training_pipeline', + 'list_training_pipelines', + 'delete_training_pipeline', + 'cancel_training_pipeline', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_pipeline_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PipelineServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_pipeline_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
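+    # google.auth.default() is patched so the test never consults real ADC
+    # (e.g. GOOGLE_APPLICATION_CREDENTIALS in the environment).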
+ with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PipelineServiceTransport() + adc.assert_called_once() + + +def test_pipeline_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PipelineServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_pipeline_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_pipeline_service_host_no_port(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_pipeline_service_host_with_port(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_pipeline_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. + transport = transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + + +def test_pipeline_service_grpc_asyncio_transport_channel(): + channel = aio.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
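+    # An explicitly supplied channel must be adopted as-is, while the host is
+    # still normalized with the default :443 port.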
+    transport = transports.PipelineServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+
+
+@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport])
+def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport])
+def test_pipeline_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_pipeline_service_grpc_lro_client():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_pipeline_service_grpc_lro_async_client():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+ assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + +def test_endpoint_path(): + project = "squid" + location = "clam" + endpoint = "whelk" + + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + actual = PipelineServiceClient.endpoint_path(project, location, endpoint) + assert expected == actual + + +def test_parse_endpoint_path(): + expected = { + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + + } + path = PipelineServiceClient.endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_endpoint_path(path) + assert expected == actual + +def test_model_path(): + project = "cuttlefish" + location = "mussel" + model = "winkle" + + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + actual = PipelineServiceClient.model_path(project, location, model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "model": "abalone", + + } + path = PipelineServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_model_path(path) + assert expected == actual + +def test_training_pipeline_path(): + project = "squid" + location = "clam" + training_pipeline = "whelk" + + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) + assert expected == actual + + +def test_parse_training_pipeline_path(): + expected = { + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", + + } + path = PipelineServiceClient.training_pipeline_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_training_pipeline_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = PipelineServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + + } + path = PipelineServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder, ) + actual = PipelineServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + + } + path = PipelineServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. 
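+    # parse_common_folder_path should invert common_folder_path exactly.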
+ actual = PipelineServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = PipelineServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + + } + path = PipelineServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project, ) + actual = PipelineServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + + } + path = PipelineServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = PipelineServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + + } + path = PipelineServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = PipelineServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py new file mode 100644 index 0000000000..9934ffb497 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -0,0 +1,1220 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.prediction_service import PredictionServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.prediction_service import PredictionServiceClient +from google.cloud.aiplatform_v1beta1.services.prediction_service import transports +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.oauth2 import service_account +from google.protobuf import struct_pb2 as struct # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PredictionServiceClient._get_default_mtls_endpoint(None) is None + assert PredictionServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert PredictionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert PredictionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert PredictionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert PredictionServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [PredictionServiceClient, PredictionServiceAsyncClient]) +def test_prediction_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_prediction_service_client_get_transport_class(): + transport = PredictionServiceClient.get_transport_class() + assert transport == transports.PredictionServiceGrpcTransport + + transport = PredictionServiceClient.get_transport_class("grpc") + assert transport == transports.PredictionServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + (PredictionServiceAsyncClient, 
transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(PredictionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceClient)) +@mock.patch.object(PredictionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceAsyncClient)) +def test_prediction_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(PredictionServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(PredictionServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
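+    # Only "true" and "false" are accepted for this variable; anything else
+    # should raise ValueError before a transport is constructed.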
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc", "true"), + (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc", "false"), + (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(PredictionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceClient)) +@mock.patch.object(PredictionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_prediction_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
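+    # Here the certificate comes from the ADC SslCredentials object rather
+    # than an explicit client_cert_source, so is_mtls and ssl_credentials are
+    # mocked to simulate both outcomes.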
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_prediction_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_prediction_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. 
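+    # The credentials_file path should be passed through to the transport
+    # untouched; no credentials object is constructed at the client layer.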
+ options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_prediction_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = PredictionServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_predict(transport: str = 'grpc', request_type=prediction_service.PredictRequest): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.predict), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse( + deployed_model_id='deployed_model_id_value', + + ) + + response = client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == prediction_service.PredictRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.PredictResponse) + + assert response.deployed_model_id == 'deployed_model_id_value' + + +def test_predict_from_dict(): + test_predict(request_type=dict) + + +@pytest.mark.asyncio +async def test_predict_async(transport: str = 'grpc_asyncio'): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = prediction_service.PredictRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.predict), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.PredictResponse( + deployed_model_id='deployed_model_id_value', + )) + + response = await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
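+    # Only the scalar field set on the mock is asserted; message fields on
+    # the response keep their proto3 defaults.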
+ assert isinstance(response, prediction_service.PredictResponse) + + assert response.deployed_model_id == 'deployed_model_id_value' + + +def test_predict_field_headers(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.PredictRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.predict), + '__call__') as call: + call.return_value = prediction_service.PredictResponse() + + client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_predict_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.PredictRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.predict), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.PredictResponse()) + + await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +def test_predict_flattened(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.predict), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.predict( + endpoint='endpoint_value', + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == 'endpoint_value' + + assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)] + + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) + + +def test_predict_flattened_error(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
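+    # Mixing a request object with flattened keyword arguments is ambiguous,
+    # so the generated surface rejects the call with ValueError.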
+    with pytest.raises(ValueError):
+        client.predict(
+            prediction_service.PredictRequest(),
+            endpoint='endpoint_value',
+            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
+            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
+        )
+
+
+@pytest.mark.asyncio
+async def test_predict_flattened_async():
+    client = PredictionServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.predict),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.PredictResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.predict(
+            endpoint='endpoint_value',
+            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
+            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)]
+
+        assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE)
+
+
+@pytest.mark.asyncio
+async def test_predict_flattened_error_async():
+    client = PredictionServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.predict(
+            prediction_service.PredictRequest(),
+            endpoint='endpoint_value',
+            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
+            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
+        )
+
+
+def test_explain(transport: str = 'grpc', request_type=prediction_service.ExplainRequest):
+    client = PredictionServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.explain),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = prediction_service.ExplainResponse(
+            deployed_model_id='deployed_model_id_value',
+        )
+
+        response = client.explain(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == prediction_service.ExplainRequest()
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, prediction_service.ExplainResponse) + + assert response.deployed_model_id == 'deployed_model_id_value' + + +def test_explain_from_dict(): + test_explain(request_type=dict) + + +@pytest.mark.asyncio +async def test_explain_async(transport: str = 'grpc_asyncio'): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = prediction_service.ExplainRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.explain), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.ExplainResponse( + deployed_model_id='deployed_model_id_value', + )) + + response = await client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.ExplainResponse) + + assert response.deployed_model_id == 'deployed_model_id_value' + + +def test_explain_field_headers(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.ExplainRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.explain), + '__call__') as call: + call.return_value = prediction_service.ExplainResponse() + + client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_explain_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.ExplainRequest() + request.endpoint = 'endpoint/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.explain), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.ExplainResponse()) + + await client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] + + +def test_explain_flattened(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
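+    # explain() accepts a flattened deployed_model_id in addition to the
+    # endpoint/instances/parameters fields it shares with predict().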
+    with mock.patch.object(
+            type(client.transport.explain),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = prediction_service.ExplainResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.explain(
+            endpoint='endpoint_value',
+            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
+            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
+            deployed_model_id='deployed_model_id_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)]
+
+        # https://github.com/googleapis/gapic-generator-python/issues/414
+        # assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE)
+
+        assert args[0].deployed_model_id == 'deployed_model_id_value'
+
+
+def test_explain_flattened_error():
+    client = PredictionServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.explain(
+            prediction_service.ExplainRequest(),
+            endpoint='endpoint_value',
+            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
+            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
+            deployed_model_id='deployed_model_id_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_explain_flattened_async():
+    client = PredictionServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.explain),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.ExplainResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.explain(
+            endpoint='endpoint_value',
+            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
+            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
+            deployed_model_id='deployed_model_id_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == 'endpoint_value'
+
+        assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)]
+
+        assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE)
+
+        assert args[0].deployed_model_id == 'deployed_model_id_value'
+
+
+@pytest.mark.asyncio
+async def test_explain_flattened_error_async():
+    client = PredictionServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.explain( + prediction_service.ExplainRequest(), + endpoint='endpoint_value', + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id='deployed_model_id_value', + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = PredictionServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PredictionServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.PredictionServiceGrpcTransport, + ) + + +def test_prediction_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.PredictionServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_prediction_service_base_transport(): + # Instantiate the base transport. 
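+    # Patching __init__ lets the base transport be constructed without any
+    # real credential wiring, so only its method stubs are exercised here.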
+ with mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.PredictionServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'predict', + 'explain', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + +def test_prediction_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PredictionServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_prediction_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PredictionServiceTransport() + adc.assert_called_once() + + +def test_prediction_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PredictionServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_prediction_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.PredictionServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_prediction_service_host_no_port(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_prediction_service_host_with_port(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_prediction_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. 
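+    # The caller-supplied channel should be adopted as-is, while the host is
+    # still normalized with the default port appended.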
+ transport = transports.PredictionServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + + +def test_prediction_service_grpc_asyncio_transport_channel(): + channel = aio.insecure_channel('http://localhost/') + + # Check that channel is used if provided. + transport = transports.PredictionServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + + +@pytest.mark.parametrize("transport_class", [transports.PredictionServiceGrpcTransport, transports.PredictionServiceGrpcAsyncIOTransport]) +def test_prediction_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + + +@pytest.mark.parametrize("transport_class", [transports.PredictionServiceGrpcTransport, transports.PredictionServiceGrpcAsyncIOTransport]) +def test_prediction_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_endpoint_path(): + project = "squid" + location = "clam" + endpoint = "whelk" + + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + actual = PredictionServiceClient.endpoint_path(project, location, endpoint) + assert expected == actual + + +def test_parse_endpoint_path(): + expected = { + "project": "octopus", + 
"location": "oyster", + "endpoint": "nudibranch", + + } + path = PredictionServiceClient.endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_endpoint_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = PredictionServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + + } + path = PredictionServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder, ) + actual = PredictionServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + + } + path = PredictionServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization, ) + actual = PredictionServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + + } + path = PredictionServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project, ) + actual = PredictionServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + + } + path = PredictionServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = PredictionServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + + } + path = PredictionServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = PredictionServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.PredictionServiceTransport, '_prep_wrapped_messages') as prep: + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.PredictionServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = PredictionServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py new file mode 100644 index 0000000000..66d80fad2a --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -0,0 +1,2096 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceClient +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client):
+    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT
+
+
+def test__get_default_mtls_endpoint():
+    api_endpoint = "example.googleapis.com"
+    api_mtls_endpoint = "example.mtls.googleapis.com"
+    sandbox_endpoint = "example.sandbox.googleapis.com"
+    sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com"
+    non_googleapi = "api.example.com"
+
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+
+
+@pytest.mark.parametrize("client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient])
+def test_specialist_pool_service_client_from_service_account_file(client_class):
+    creds = credentials.AnonymousCredentials()
+    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
+        factory.return_value = creds
+        client = client_class.from_service_account_file("dummy/file/path.json")
+        assert client.transport._credentials == creds
+
+        client = client_class.from_service_account_json("dummy/file/path.json")
+        assert client.transport._credentials == creds
+
+        assert client.transport._host == 'aiplatform.googleapis.com:443'
+
+
+def test_specialist_pool_service_client_get_transport_class():
+    transport = SpecialistPoolServiceClient.get_transport_class()
+    assert transport == transports.SpecialistPoolServiceGrpcTransport
+
+    transport = SpecialistPoolServiceClient.get_transport_class("grpc")
+    assert transport == transports.SpecialistPoolServiceGrpcTransport
+
+
+@pytest.mark.parametrize("client_class,transport_class,transport_name", [
+    (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"),
+    (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio")
+])
+@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient))
+@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient))
+def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name):
+    # Check that when a transport instance is provided, no new transport is created.
+    with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc:
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials()
+        )
+        client = client_class(transport=transport)
+        gtc.assert_not_called()
+
+    # Check that when the transport is named via a string, the transport class is looked up.
+    with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc:
+        client = client_class(transport=transport_name)
+        gtc.assert_called()
+
+    # Check the case api_endpoint is provided.
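+    # An endpoint supplied explicitly via client_options should be handed to
+    # the transport verbatim as its host, bypassing mTLS endpoint switching.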
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
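+    # Only the strings "true" and "false" are accepted for this variable; any
+    # other value should be rejected with a ValueError.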
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) +@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + ssl_channel_creds = mock.Mock() + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ssl_credentials_mock.return_value + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_specialist_pool_service_client_client_options_from_dict(): + with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = SpecialistPoolServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + + response = client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_specialist_pool_from_dict(): + test_create_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio'): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.CreateSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
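+    # CreateSpecialistPool is a long-running operation, so the call yields an
+    # operation future rather than the finished resource.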
+ assert isinstance(response, future.Future) + + +def test_create_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.CreateSpecialistPoolRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_create_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.CreateSpecialistPoolRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +def test_create_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/op') + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_specialist_pool( + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == 'parent_value' + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + + +def test_create_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.create_specialist_pool(
+            specialist_pool_service.CreateSpecialistPoolRequest(),
+            parent='parent_value',
+            specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'),
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_specialist_pool(
+            parent='parent_value',
+            specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+        assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value')
+
+
+@pytest.mark.asyncio
+async def test_create_specialist_pool_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_specialist_pool(
+            specialist_pool_service.CreateSpecialistPoolRequest(),
+            parent='parent_value',
+            specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'),
+        )
+
+
+def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool.SpecialistPool(
+            name='name_value',
+
+            display_name='display_name_value',
+
+            specialist_managers_count=2662,
+
+            specialist_manager_emails=['specialist_manager_emails_value'],
+
+            pending_data_labeling_jobs=['pending_data_labeling_jobs_value'],
+
+        )
+
+        response = client.get_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == specialist_pool_service.GetSpecialistPoolRequest()
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, specialist_pool.SpecialistPool) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.specialist_managers_count == 2662 + + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + + +def test_get_specialist_pool_from_dict(): + test_get_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio'): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.GetSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( + name='name_value', + display_name='display_name_value', + specialist_managers_count=2662, + specialist_manager_emails=['specialist_manager_emails_value'], + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + )) + + response = await client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, specialist_pool.SpecialistPool) + + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.specialist_managers_count == 2662 + + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + + +def test_get_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.GetSpecialistPoolRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), + '__call__') as call: + call.return_value = specialist_pool.SpecialistPool() + + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.GetSpecialistPoolRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.get_specialist_pool),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool())
+
+        await client.get_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']
+
+
+def test_get_specialist_pool_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool.SpecialistPool()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_specialist_pool(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_get_specialist_pool_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_specialist_pool(
+            specialist_pool_service.GetSpecialistPoolRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_specialist_pool(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_get_specialist_pool_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_specialist_pool(
+            specialist_pool_service.GetSpecialistPoolRequest(),
+            name='name_value',
+        )
+
+
+def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token='next_page_token_value', + + ) + + response = client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListSpecialistPoolsPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_specialist_pools_from_dict(): + test_list_specialist_pools(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio'): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.ListSpecialistPoolsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token='next_page_token_value', + )) + + response = await client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) + + assert response.next_page_token == 'next_page_token_value' + + +def test_list_specialist_pools_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.ListSpecialistPoolsRequest() + request.parent = 'parent/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), + '__call__') as call: + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() + + client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_specialist_pools_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+    request = specialist_pool_service.ListSpecialistPoolsRequest()
+    request.parent = 'parent/value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse())
+
+        await client.list_specialist_pools(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']
+
+
+def test_list_specialist_pools_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_specialist_pools(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+def test_list_specialist_pools_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_specialist_pools(
+            specialist_pool_service.ListSpecialistPoolsRequest(),
+            parent='parent_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_specialist_pools(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == 'parent_value'
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_specialist_pools(
+            specialist_pool_service.ListSpecialistPoolsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_specialist_pools_pager():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='abc',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[],
+                next_page_token='def',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='ghi',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_specialist_pools(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, specialist_pool.SpecialistPool)
+                   for i in results)
+
+def test_list_specialist_pools_pages():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='abc',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[],
+                next_page_token='def',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='ghi',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_specialist_pools(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_async_pager():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
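+        # Each staged response is consumed by one page fetch; the trailing
+        # RuntimeError guards against fetching more pages than were staged.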
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='abc',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[],
+                next_page_token='def',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='ghi',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_specialist_pools(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, specialist_pool.SpecialistPool)
+                   for i in responses)
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_async_pages():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_specialist_pools),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='abc',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[],
+                next_page_token='def',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token='ghi',
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_specialist_pools(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.delete_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest()
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future) + + +def test_delete_specialist_pool_from_dict(): + test_delete_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio'): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + request.name = 'name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] + + +def test_delete_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.delete_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_specialist_pool(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+def test_delete_specialist_pool_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_specialist_pool(
+            specialist_pool_service.DeleteSpecialistPoolRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_specialist_pool(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == 'name_value'
+
+
+@pytest.mark.asyncio
+async def test_delete_specialist_pool_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_specialist_pool(
+            specialist_pool_service.DeleteSpecialistPoolRequest(),
+            name='name_value',
+        )
+
+
+def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/spam')
+
+        response = client.update_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest()
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future) + + +def test_update_specialist_pool_from_dict(): + test_update_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio'): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name='operations/spam') + ) + + response = await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + request.specialist_pool.name = 'specialist_pool.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') + + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_update_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + request.specialist_pool.name = 'specialist_pool.name/value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + + await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
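+    # (x-goog-request-params is the routing header the generated client
+    # attaches so the backend can route by resource name; it must mirror the
+    # request field set above.)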
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'specialist_pool.name=specialist_pool.name/value',
+    ) in kw['metadata']
+
+
+def test_update_specialist_pool_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_specialist_pool(
+            specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value')
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value'])
+
+
+def test_update_specialist_pool_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_specialist_pool(
+            specialist_pool_service.UpdateSpecialistPoolRequest(),
+            specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+
+@pytest.mark.asyncio
+async def test_update_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_specialist_pool),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_specialist_pool(
+            specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'),
+            update_mask=field_mask.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value')
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value'])
+
+
+@pytest.mark.asyncio
+async def test_update_specialist_pool_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.update_specialist_pool( + specialist_pool_service.UpdateSpecialistPoolRequest(), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = SpecialistPoolServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize("transport_class", [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.SpecialistPoolServiceGrpcTransport, + ) + + +def test_specialist_pool_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_specialist_pool_service_base_transport(): + # Instantiate the base transport. 
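+    # (Patching __init__ to return None skips credential resolution, so the
+    # otherwise-abstract base transport can be constructed and probed.)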
+ with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + 'create_specialist_pool', + 'get_specialist_pool', + 'list_specialist_pools', + 'delete_specialist_pool', + 'update_specialist_pool', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_specialist_pool_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.SpecialistPoolServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + quota_project_id="octopus", + ) + + +def test_specialist_pool_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.SpecialistPoolServiceTransport() + adc.assert_called_once() + + +def test_specialist_pool_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + SpecialistPoolServiceClient() + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id=None, + ) + + +def test_specialist_pool_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
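+    # (ADC = Application Default Credentials, resolved via google.auth.default();
+    # the mock below verifies the expected scopes and quota project are forwarded.)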
+ with mock.patch.object(auth, 'default') as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), + quota_project_id="octopus", + ) + +def test_specialist_pool_service_host_no_port(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:443' + + +def test_specialist_pool_service_host_with_port(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + ) + assert client.transport._host == 'aiplatform.googleapis.com:8000' + + +def test_specialist_pool_service_grpc_transport_channel(): + channel = grpc.insecure_channel('http://localhost/') + + # Check that channel is used if provided. + transport = transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + + +def test_specialist_pool_service_grpc_asyncio_transport_channel(): + channel = aio.insecure_channel('http://localhost/') + + # Check that channel is used if provided. + transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + + +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + + +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +def test_specialist_pool_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + 
        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_specialist_pool_service_grpc_lro_client():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_specialist_pool_service_grpc_lro_async_client():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+        transport='grpc_asyncio',
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(
+        transport.operations_client,
+        operations_v1.OperationsAsyncClient,
+    )
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_specialist_pool_path():
+    project = "squid"
+    location = "clam"
+    specialist_pool = "whelk"
+
+    expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool)
+    actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool)
+    assert expected == actual
+
+
+def test_parse_specialist_pool_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "specialist_pool": "nudibranch",
+    }
+    path = SpecialistPoolServiceClient.specialist_pool_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "cuttlefish"
+
+    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account)
+    actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "mussel",
+    }
+    path = SpecialistPoolServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "winkle"
+
+    expected = "folders/{folder}".format(folder=folder)
+    actual = SpecialistPoolServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "nautilus",
+    }
+    path = SpecialistPoolServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "scallop"
+
+    expected = "organizations/{organization}".format(organization=organization)
+    actual = SpecialistPoolServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "abalone",
+    }
+    path = SpecialistPoolServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_common_organization_path(path)
+    assert expected == actual
+
+
+def test_common_project_path():
+    project = "squid"
+
+    expected = "projects/{project}".format(project=project)
+    actual = SpecialistPoolServiceClient.common_project_path(project)
+    assert expected == actual
+
+
+def test_parse_common_project_path():
+    expected = {
+        "project": "clam",
+    }
+    path = SpecialistPoolServiceClient.common_project_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_common_project_path(path)
+    assert expected == actual
+
+
+def test_common_location_path():
+    project = "whelk"
+    location = "octopus"
+
+    expected = "projects/{project}/locations/{location}".format(project=project, location=location)
+    actual = SpecialistPoolServiceClient.common_location_path(project, location)
+    assert expected == actual
+
+
+def test_parse_common_location_path():
+    expected = {
+        "project": "oyster",
+        "location": "nudibranch",
+    }
+    path = SpecialistPoolServiceClient.common_location_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_with_default_client_info():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep:
+        client = SpecialistPoolServiceClient(
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep:
+        transport_class = SpecialistPoolServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/test_dataset_service.py b/tests/unit/gapic/test_dataset_service.py
deleted file mode 100644
index bdf4410884..0000000000
--- a/tests/unit/gapic/test_dataset_service.py
+++ /dev/null
@@ -1,1140 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers -from google.cloud.aiplatform_v1beta1.services.dataset_service import transports -from google.cloud.aiplatform_v1beta1.types import annotation -from google.cloud.aiplatform_v1beta1.types import annotation_spec -from google.cloud.aiplatform_v1beta1.types import data_item -from google.cloud.aiplatform_v1beta1.types import dataset -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import dataset_service -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def test_dataset_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = DatasetServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = DatasetServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_dataset_service_client_client_options(): - # Check the default options have their expected values. - assert ( - DatasetServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. 
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.DatasetServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = DatasetServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_dataset_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.DatasetServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = DatasetServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.CreateDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - - -def test_create_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_dataset( - dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), - ) - - -def test_get_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - - response = client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == "etag_value" - - -def test_get_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetDatasetRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - call.return_value = dataset.Dataset() - client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", - ) - - -def test_update_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.UpdateDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - # Designate an appropriate return value for the call. 
- call.return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - - response = client.update_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == "etag_value" - - -def test_update_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_dataset.Dataset() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_dataset( - dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_list_datasets(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDatasetsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_datasets_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDatasetsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - call.return_value = dataset_service.ListDatasetsResponse() - client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_datasets_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_datasets(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_datasets_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", - ) - - -def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - results = [i for i in client.list_datasets(request={},)] - assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - - -def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - pages = list(client.list_datasets(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.DeleteDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", - ) - - -def test_import_data(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ImportDataRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.import_data(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_import_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] - - -def test_import_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.import_data( - dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - -def test_export_data(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ExportDataRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.export_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) - - -def test_export_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.export_data( - dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - -def test_list_data_items(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDataItemsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_data_items_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDataItemsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - call.return_value = dataset_service.ListDataItemsResponse() - client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_data_items_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_data_items(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_data_items_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", - ) - - -def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - results = [i for i in client.list_data_items(request={},)] - assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) - - -def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - pages = list(client.list_data_items(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_get_annotation_spec(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetAnnotationSpecRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - - response = client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
-    assert isinstance(response, annotation_spec.AnnotationSpec)
-    assert response.name == "name_value"
-    assert response.display_name == "display_name_value"
-    assert response.etag == "etag_value"
-
-
-def test_get_annotation_spec_field_headers():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = dataset_service.GetAnnotationSpecRequest(name="name/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_annotation_spec), "__call__"
-    ) as call:
-        call.return_value = annotation_spec.AnnotationSpec()
-        client.get_annotation_spec(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_annotation_spec_flattened():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_annotation_spec), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = annotation_spec.AnnotationSpec()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.get_annotation_spec(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_get_annotation_spec_flattened_error():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_annotation_spec(
-            dataset_service.GetAnnotationSpecRequest(), name="name_value",
-        )
-
-
-def test_list_annotations(transport: str = "grpc"):
-    client = DatasetServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = dataset_service.ListAnnotationsRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = dataset_service.ListAnnotationsResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_annotations(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListAnnotationsPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_annotations_field_headers():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = dataset_service.ListAnnotationsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        call.return_value = dataset_service.ListAnnotationsResponse()
-        client.list_annotations(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_annotations_flattened():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = dataset_service.ListAnnotationsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_annotations(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_annotations_flattened_error():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_annotations(
-            dataset_service.ListAnnotationsRequest(), parent="parent_value",
-        )
-
-
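The pager and pages tests that follow exercise the client-side pagination wrapper: iterating the pager yields individual items across page boundaries, while its .pages view yields one raw response per page until next_page_token comes back empty. A minimal, self-contained sketch of that pattern (Page, fake_rpc, and Pager below are illustrative stand-ins, not the library's actual classes):

from dataclasses import dataclass, field
from typing import Callable, Iterator, List


@dataclass
class Page:
    items: List[str] = field(default_factory=list)
    next_page_token: str = ""


def fake_rpc(page_token: str = "") -> Page:
    # Stand-in backend: three pages keyed by the previous page's token.
    pages = {
        "": Page(items=["a", "b"], next_page_token="abc"),
        "abc": Page(items=[], next_page_token="def"),
        "def": Page(items=["c"], next_page_token=""),
    }
    return pages[page_token]


class Pager:
    def __init__(self, method: Callable[..., Page]) -> None:
        self._method = method

    @property
    def pages(self) -> Iterator[Page]:
        # Yield whole responses, following tokens until one comes back empty.
        page = self._method()
        yield page
        while page.next_page_token:
            page = self._method(page.next_page_token)
            yield page

    def __iter__(self) -> Iterator[str]:
        # Flatten the pages into a single stream of items.
        for page in self.pages:
            yield from page.items


assert list(Pager(fake_rpc)) == ["a", "b", "c"]
assert [p.next_page_token for p in Pager(fake_rpc).pages] == ["abc", "def", ""]

The generated ListAnnotationsPager behaves the same way, re-invoking the wrapped transport method with the previous response's token until pagination is exhausted.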
-def test_list_annotations_pager():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            dataset_service.ListAnnotationsResponse(
-                annotations=[
-                    annotation.Annotation(),
-                    annotation.Annotation(),
-                    annotation.Annotation(),
-                ],
-                next_page_token="abc",
-            ),
-            dataset_service.ListAnnotationsResponse(
-                annotations=[], next_page_token="def",
-            ),
-            dataset_service.ListAnnotationsResponse(
-                annotations=[annotation.Annotation(),], next_page_token="ghi",
-            ),
-            dataset_service.ListAnnotationsResponse(
-                annotations=[annotation.Annotation(), annotation.Annotation(),],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_annotations(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, annotation.Annotation) for i in results)
-
-
-def test_list_annotations_pages():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
- call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], - ), - RuntimeError, - ) - pages = list(client.list_annotations(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = DatasetServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.DatasetServiceGrpcTransport,) - - -def test_dataset_service_base_transport(): - # Instantiate the base transport. - transport = transports.DatasetServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_dataset", - "get_dataset", - "update_dataset", - "list_datasets", - "delete_dataset", - "import_data", - "export_data", - "list_data_items", - "get_annotation_spec", - "list_annotations", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_dataset_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. 
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - DatasetServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_dataset_service_host_no_port(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_dataset_service_host_with_port(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.DatasetServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel - - -def test_dataset_service_grpc_lro_client(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_dataset_path(): - project = "squid" - location = "clam" - dataset = "whelk" - - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) - actual = DatasetServiceClient.dataset_path(project, location, dataset) - assert expected == actual diff --git a/tests/unit/gapic/test_endpoint_service.py b/tests/unit/gapic/test_endpoint_service.py deleted file mode 100644 index 4059cdb819..0000000000 --- a/tests/unit/gapic/test_endpoint_service.py +++ /dev/null @@ -1,771 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers -from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports -from google.cloud.aiplatform_v1beta1.types import accelerator_type -from google.cloud.aiplatform_v1beta1.types import endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def test_endpoint_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = EndpointServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = EndpointServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_endpoint_service_client_client_options(): - # Check the default options have their expected values. - assert ( - EndpointServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.EndpointServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = EndpointServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_endpoint_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.EndpointServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = EndpointServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.CreateEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - - -def test_create_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_endpoint( - endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - -def test_get_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.GetEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - - response = client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.etag == "etag_value" - - -def test_get_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.GetEndpointRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. 
-    with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call:
-        call.return_value = endpoint.Endpoint()
-        client.get_endpoint(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_endpoint_flattened():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = endpoint.Endpoint()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.get_endpoint(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_get_endpoint_flattened_error():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_endpoint(
-            endpoint_service.GetEndpointRequest(), name="name_value",
-        )
-
-
-def test_list_endpoints(transport: str = "grpc"):
-    client = EndpointServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = endpoint_service.ListEndpointsRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = endpoint_service.ListEndpointsResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_endpoints(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListEndpointsPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_endpoints_field_headers():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = endpoint_service.ListEndpointsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
-        call.return_value = endpoint_service.ListEndpointsResponse()
-        client.list_endpoints(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
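The field-header assertions here and above check the x-goog-request-params metadata entry that the generated client attaches so the backend can route a request by its resource name. A minimal sketch of how such an entry can be built from request fields; build_routing_metadata is a hypothetical helper for illustration, while the real clients assemble the equivalent pairs inline:

from urllib.parse import quote


def build_routing_metadata(**fields):
    # URL-encode each routing field, keeping "/" intact, and join with "&".
    params = "&".join(
        "{}={}".format(name, quote(str(value), safe="/"))
        for name, value in fields.items()
    )
    return [("x-goog-request-params", params)]


assert build_routing_metadata(parent="parent/value") == [
    ("x-goog-request-params", "parent=parent/value")
]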
-def test_list_endpoints_flattened():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = endpoint_service.ListEndpointsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_endpoints(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_endpoints_flattened_error():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_endpoints(
-            endpoint_service.ListEndpointsRequest(), parent="parent_value",
-        )
-
-
-def test_list_endpoints_pager():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                ],
-                next_page_token="abc",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[], next_page_token="def",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_endpoints(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, endpoint.Endpoint) for i in results)
-
-
-def test_list_endpoints_pages():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                ],
-                next_page_token="abc",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[], next_page_token="def",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_endpoints(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_update_endpoint(transport: str = "grpc"):
-    client = EndpointServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
- request = endpoint_service.UpdateEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - - response = client.update_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.etag == "etag_value" - - -def test_update_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_endpoint.Endpoint() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_endpoint( - endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_delete_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeleteEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
-    with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.delete_endpoint(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_delete_endpoint_flattened_error():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.delete_endpoint(
-            endpoint_service.DeleteEndpointRequest(), name="name_value",
-        )
-
-
-def test_deploy_model(transport: str = "grpc"):
-    client = EndpointServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = endpoint_service.DeployModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.deploy_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.deploy_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_deploy_model_flattened():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.deploy_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.deploy_model(
-            endpoint="endpoint_value",
-            deployed_model=gca_endpoint.DeployedModel(
-                dedicated_resources=machine_resources.DedicatedResources(
-                    machine_spec=machine_resources.MachineSpec(
-                        machine_type="machine_type_value"
-                    )
-                )
-            ),
-            traffic_split={"key_value": 541},
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].endpoint == "endpoint_value"
-        assert args[0].deployed_model == gca_endpoint.DeployedModel(
-            dedicated_resources=machine_resources.DedicatedResources(
-                machine_spec=machine_resources.MachineSpec(
-                    machine_type="machine_type_value"
-                )
-            )
-        )
-        assert args[0].traffic_split == {"key_value": 541}
-
-
-def test_deploy_model_flattened_error():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.deploy_model(
-            endpoint_service.DeployModelRequest(),
-            endpoint="endpoint_value",
-            deployed_model=gca_endpoint.DeployedModel(
-                dedicated_resources=machine_resources.DedicatedResources(
-                    machine_spec=machine_resources.MachineSpec(
-                        machine_type="machine_type_value"
-                    )
-                )
-            ),
-            traffic_split={"key_value": 541},
-        )
-
-
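The deploy and undeploy tests here drive the flattened traffic_split argument, a mapping of deployed-model IDs to integer percentages (the {"key_value": 541} above is just mock data that the stubbed transport never validates). On the real service the percentages configured across an endpoint are expected to add up to 100; a minimal sketch of that invariant, with validate_traffic_split as a hypothetical helper rather than anything the library exports:

from typing import Dict


def validate_traffic_split(traffic_split: Dict[str, int]) -> None:
    # Reject splits whose percentages do not cover exactly 100% of traffic.
    total = sum(traffic_split.values())
    if total != 100:
        raise ValueError(
            "traffic_split percentages must sum to 100, got {}".format(total)
        )


validate_traffic_split({"model-a": 80, "model-b": 20})  # a valid 80/20 split
try:
    validate_traffic_split({"model-a": 50})
except ValueError as exc:
    assert "sum to 100" in str(exc)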
-def test_undeploy_model(transport: str = "grpc"):
-    client = EndpointServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = endpoint_service.UndeployModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.undeploy_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_undeploy_model_flattened():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.undeploy_model(
-            endpoint="endpoint_value",
-            deployed_model_id="deployed_model_id_value",
-            traffic_split={"key_value": 541},
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].endpoint == "endpoint_value"
-        assert args[0].deployed_model_id == "deployed_model_id_value"
-        assert args[0].traffic_split == {"key_value": 541}
-
-
-def test_undeploy_model_flattened_error():
-    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.undeploy_model(
-            endpoint_service.UndeployModelRequest(),
-            endpoint="endpoint_value",
-            deployed_model_id="deployed_model_id_value",
-            traffic_split={"key_value": 541},
-        )
-
-
-def test_credentials_transport_error():
-    # It is an error to provide credentials and a transport instance.
-    transport = transports.EndpointServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = EndpointServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.EndpointServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = EndpointServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
- client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.EndpointServiceGrpcTransport,) - - -def test_endpoint_service_base_transport(): - # Instantiate the base transport. - transport = transports.EndpointServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_endpoint", - "get_endpoint", - "list_endpoints", - "update_endpoint", - "delete_endpoint", - "deploy_model", - "undeploy_model", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_endpoint_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - EndpointServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_endpoint_service_host_no_port(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_endpoint_service_host_with_port(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.EndpointServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel - - -def test_endpoint_service_grpc_lro_client(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_endpoint_path(): - project = "squid" - location = "clam" - endpoint = "whelk" - - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) - actual = EndpointServiceClient.endpoint_path(project, location, endpoint) - assert expected == actual diff --git a/tests/unit/gapic/test_job_service.py b/tests/unit/gapic/test_job_service.py deleted file mode 100644 index 92ed0d37e3..0000000000 --- a/tests/unit/gapic/test_job_service.py +++ /dev/null @@ -1,2118 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from unittest import mock
-
-import grpc
-import math
-import pytest
-
-from google import auth
-from google.api_core import client_options
-from google.api_core import future
-from google.api_core import operations_v1
-from google.auth import credentials
-from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
-from google.cloud.aiplatform_v1beta1.services.job_service import pagers
-from google.cloud.aiplatform_v1beta1.services.job_service import transports
-from google.cloud.aiplatform_v1beta1.types import accelerator_type
-from google.cloud.aiplatform_v1beta1.types import (
-    accelerator_type as gca_accelerator_type,
-)
-from google.cloud.aiplatform_v1beta1.types import batch_prediction_job
-from google.cloud.aiplatform_v1beta1.types import (
-    batch_prediction_job as gca_batch_prediction_job,
-)
-from google.cloud.aiplatform_v1beta1.types import completion_stats
-from google.cloud.aiplatform_v1beta1.types import (
-    completion_stats as gca_completion_stats,
-)
-from google.cloud.aiplatform_v1beta1.types import custom_job
-from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
-from google.cloud.aiplatform_v1beta1.types import data_labeling_job
-from google.cloud.aiplatform_v1beta1.types import (
-    data_labeling_job as gca_data_labeling_job,
-)
-from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job
-from google.cloud.aiplatform_v1beta1.types import (
-    hyperparameter_tuning_job as gca_hyperparameter_tuning_job,
-)
-from google.cloud.aiplatform_v1beta1.types import io
-from google.cloud.aiplatform_v1beta1.types import job_service
-from google.cloud.aiplatform_v1beta1.types import job_state
-from google.cloud.aiplatform_v1beta1.types import machine_resources
-from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters
-from google.cloud.aiplatform_v1beta1.types import (
-    manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters,
-)
-from google.cloud.aiplatform_v1beta1.types import operation as gca_operation
-from google.cloud.aiplatform_v1beta1.types import study
-from google.longrunning import operations_pb2
-from google.oauth2 import service_account
-from google.protobuf import any_pb2 as any  # type: ignore
-from google.protobuf import duration_pb2 as duration  # type: ignore
-from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
-from google.protobuf import struct_pb2 as struct  # type: ignore
-from google.protobuf import timestamp_pb2 as timestamp  # type: ignore
-from google.rpc import status_pb2 as status  # type: ignore
-from google.type import money_pb2 as money  # type: ignore
-
-
-def test_job_service_client_from_service_account_file():
-    creds = credentials.AnonymousCredentials()
-    with mock.patch.object(
-        service_account.Credentials, "from_service_account_file"
-    ) as factory:
-        factory.return_value = creds
-        client = JobServiceClient.from_service_account_file("dummy/file/path.json")
-        assert client._transport._credentials == creds
-
-        client = JobServiceClient.from_service_account_json("dummy/file/path.json")
-        assert client._transport._credentials == creds
-
-        assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
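The test above pins down the from_service_account_file convenience constructor and its from_service_account_json alias. Typical usage outside the test suite looks like the sketch below; the key-file path is a placeholder, and constructing the credentials explicitly through google.oauth2.service_account is an equivalent alternative:

from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.oauth2 import service_account

# Shortcut: let the client load the service-account key itself.
client = JobServiceClient.from_service_account_file("service-account.json")

# Equivalent: build the credentials explicitly and pass them to the constructor.
creds = service_account.Credentials.from_service_account_file(
    "service-account.json",
    scopes=("https://www.googleapis.com/auth/cloud-platform",),
)
client = JobServiceClient(credentials=creds)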
-def test_job_service_client_client_options():
-    # Check the default options have their expected values.
-    assert JobServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com"
-
-    # Check that options can be customized.
-    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_transport_class"
-    ) as gtc:
-        transport = gtc.return_value = mock.MagicMock()
-        client = JobServiceClient(client_options=options)
-        transport.assert_called_once_with(credentials=None, host="squid.clam.whelk")
-
-
-def test_job_service_client_client_options_from_dict():
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_transport_class"
-    ) as gtc:
-        transport = gtc.return_value = mock.MagicMock()
-        client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"})
-        transport.assert_called_once_with(credentials=None, host="squid.clam.whelk")
-
-
-def test_create_custom_job(transport: str = "grpc"):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.CreateCustomJobRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_custom_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_custom_job.CustomJob(
-            name="name_value",
-            display_name="display_name_value",
-            state=job_state.JobState.JOB_STATE_QUEUED,
-        )
-
-        response = client.create_custom_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, gca_custom_job.CustomJob)
-    assert response.name == "name_value"
-    assert response.display_name == "display_name_value"
-    assert response.state == job_state.JobState.JOB_STATE_QUEUED
-
-
-def test_create_custom_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_custom_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_custom_job.CustomJob()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.create_custom_job(
-            parent="parent_value",
-            custom_job=gca_custom_job.CustomJob(name="name_value"),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-        assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value")
-
-
-def test_create_custom_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
- with pytest.raises(ValueError): - client.create_custom_job( - job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - -def test_get_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetCustomJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - call.return_value = custom_job.CustomJob() - client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", - ) - - -def test_list_custom_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListCustomJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_custom_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListCustomJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - call.return_value = job_service.ListCustomJobsResponse() - client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_custom_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_custom_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_custom_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", - ) - - -def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - results = [i for i in client.list_custom_jobs(request={},)] - assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in results) - - -def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - pages = list(client.list_custom_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", - ) - - -def test_cancel_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", - ) - - -def test_create_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - - response = client.create_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.datasets == ["datasets_value"] - assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == "inputs_schema_uri_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_create_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_data_labeling_job.DataLabelingJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) - - -def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_data_labeling_job( - job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - -def test_get_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
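# The *_flattened and *_flattened_error tests above pin down a GAPIC calling
# convention: every method accepts either a full request message or the
# individual ("flattened") keyword fields, and mixing the two raises
# ValueError before any RPC is attempted. A stdlib-only sketch of that guard
# follows; guard_flattened is a hypothetical helper (the generated clients
# inline this check in each method) and the message wording is illustrative.
from typing import Any, Optional


def guard_flattened(request: Optional[Any], *flattened: Any) -> None:
    """Reject calls that set both a request object and flattened fields."""
    if request is not None and any(arg is not None for arg in flattened):
        raise ValueError(
            "If the `request` argument is set, the individual field "
            "arguments must not be set."
        )


# Mirrors test_create_data_labeling_job_flattened_error: both styles at once.
try:
    guard_flattened(object(), "parent_value", None)
except ValueError:
    pass  # expected: request object and flattened field were mixed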
- call.return_value = data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - - response = client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.datasets == ["datasets_value"] - assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == "inputs_schema_uri_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_get_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetDataLabelingJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = data_labeling_job.DataLabelingJob() - client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = data_labeling_job.DataLabelingJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", - ) - - -def test_list_data_labeling_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListDataLabelingJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_data_labeling_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListDataLabelingJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - call.return_value = job_service.ListDataLabelingJobsResponse() - client.list_data_labeling_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListDataLabelingJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_data_labeling_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", - ) - - -def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
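# The *_field_headers tests above check a routing rule: any request field
# that would appear in the HTTP/1.1 URI (name, parent, ...) is mirrored into
# the gRPC metadata as an "x-goog-request-params" entry, e.g.
# ("x-goog-request-params", "parent=parent/value"), so the backend can route
# the call without parsing the request body. A sketch of assembling that
# metadata pair; routing_params is an illustrative helper, and the encoding
# choice (slashes kept literal) is an assumption that matches the exact
# values asserted in these tests.
from urllib.parse import quote


def routing_params(**fields: str) -> tuple:
    """Build the x-goog-request-params metadata entry from request fields."""
    value = "&".join(
        f"{k}={quote(v, safe='/')}" for k, v in sorted(fields.items())
    )
    return ("x-goog-request-params", value)


assert routing_params(name="name/value") == (
    "x-goog-request-params",
    "name=name/value",
)
assert routing_params(parent="parent/value") == (
    "x-goog-request-params",
    "parent=parent/value",
)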
- with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - next_page_token="abc", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_data_labeling_jobs(request={},)] - assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) - - -def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - next_page_token="abc", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_data_labeling_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", - ) - - -def test_cancel_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", - ) - - -def test_create_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.max_trial_count == 1609 - assert response.parallel_trial_count == 2128 - assert response.max_failed_trial_count == 2317 - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) - - -def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_hyperparameter_tuning_job( - job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), - ) - - -def test_get_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.max_trial_count == 1609 - assert response.parallel_trial_count == 2128 - assert response.max_failed_trial_count == 2317 - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetHyperparameterTuningJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - client.get_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_list_hyperparameter_tuning_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListHyperparameterTuningJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_hyperparameter_tuning_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListHyperparameterTuningJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - call.return_value = job_service.ListHyperparameterTuningJobsResponse() - client.list_hyperparameter_tuning_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListHyperparameterTuningJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_hyperparameter_tuning_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", - ) - - -def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="abc", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="ghi", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_hyperparameter_tuning_jobs(request={},)] - assert len(results) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results - ) - - -def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="abc", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="ghi", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) - - -def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_cancel_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
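# The delete_* tests above only assert that the client hands back a
# google.api_core future wrapping an operations_pb2.Operation, since the
# stub is mocked. Against a real service the caller would block on that
# future rather than inspect it. A hedged usage sketch, assuming this
# generation's import layout and real ADC credentials; the resource name
# and timeout below are placeholders.
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient


def delete_custom_job_and_wait(name: str) -> None:
    client = JobServiceClient()  # picks up ADC, per test_job_service_auth_adc
    operation = client.delete_custom_job(name=name)
    operation.result(timeout=300)  # block until server-side deletion finishes


# delete_custom_job_and_wait(
#     "projects/my-project/locations/us-central1/customJobs/123"
# )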
- with pytest.raises(ValueError): - client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_create_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.model == "model_value" - - assert response.generate_explanation is True - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) - - -def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_batch_prediction_job( - job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - -def test_get_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = job_service.GetBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.model == "model_value" - - assert response.generate_explanation is True - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetBatchPredictionJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - call.return_value = batch_prediction_job.BatchPredictionJob() - client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = batch_prediction_job.BatchPredictionJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", - ) - - -def test_list_batch_prediction_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = job_service.ListBatchPredictionJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListBatchPredictionJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - call.return_value = job_service.ListBatchPredictionJobsResponse() - client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_batch_prediction_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", - ) - - -def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_batch_prediction_jobs(request={},)] - assert len(results) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results - ) - - -def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", - ) - - -def test_cancel_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
- transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = JobServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.JobServiceGrpcTransport,) - - -def test_job_service_base_transport(): - # Instantiate the base transport. - transport = transports.JobServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_custom_job", - "get_custom_job", - "list_custom_jobs", - "delete_custom_job", - "cancel_custom_job", - "create_data_labeling_job", - "get_data_labeling_job", - "list_data_labeling_jobs", - "delete_data_labeling_job", - "cancel_data_labeling_job", - "create_hyperparameter_tuning_job", - "get_hyperparameter_tuning_job", - "list_hyperparameter_tuning_jobs", - "delete_hyperparameter_tuning_job", - "cancel_hyperparameter_tuning_job", - "create_batch_prediction_job", - "get_batch_prediction_job", - "list_batch_prediction_jobs", - "delete_batch_prediction_job", - "cancel_batch_prediction_job", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_job_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - JobServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_job_service_host_no_port(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_job_service_host_with_port(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_job_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.JobServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel - - -def test_job_service_grpc_lro_client(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object. 
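# The host tests above fix the endpoint-normalization contract: a bare
# api_endpoint gains the default ":443", while an explicitly given port is
# kept verbatim. A usage sketch taken directly from those tests; anonymous
# credentials keep it self-contained (no RPC is issued at construction).
from google.api_core import client_options
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient

client = JobServiceClient(
    credentials=credentials.AnonymousCredentials(),
    client_options=client_options.ClientOptions(
        api_endpoint="aiplatform.googleapis.com"  # no port: ":443" is appended
    ),
    transport="grpc",
)
assert client._transport._host == "aiplatform.googleapis.com:443"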
- assert transport.operations_client is transport.operations_client - - -def test_custom_job_path(): - project = "squid" - location = "clam" - custom_job = "whelk" - - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) - actual = JobServiceClient.custom_job_path(project, location, custom_job) - assert expected == actual - - -def test_batch_prediction_job_path(): - project = "squid" - location = "clam" - batch_prediction_job = "whelk" - - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, location=location, batch_prediction_job=batch_prediction_job, - ) - actual = JobServiceClient.batch_prediction_job_path( - project, location, batch_prediction_job - ) - assert expected == actual - - -def test_hyperparameter_tuning_job_path(): - project = "squid" - location = "clam" - hyperparameter_tuning_job = "whelk" - - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) - actual = JobServiceClient.hyperparameter_tuning_job_path( - project, location, hyperparameter_tuning_job - ) - assert expected == actual - - -def test_data_labeling_job_path(): - project = "squid" - location = "clam" - data_labeling_job = "whelk" - - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) - actual = JobServiceClient.data_labeling_job_path( - project, location, data_labeling_job - ) - assert expected == actual diff --git a/tests/unit/gapic/test_model_service.py b/tests/unit/gapic/test_model_service.py deleted file mode 100644 index 854272f8a5..0000000000 --- a/tests/unit/gapic/test_model_service.py +++ /dev/null @@ -1,1223 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
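# Closing out the deleted test_job_service.py above are the *_path helper
# tests: each classmethod renders a fully qualified resource name from its
# parts. A usage sketch whose format string is taken straight from
# test_custom_job_path; the project, location, and job IDs are placeholders.
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient

name = JobServiceClient.custom_job_path("my-project", "us-central1", "123")
assert name == "projects/my-project/locations/us-central1/customJobs/123"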
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient -from google.cloud.aiplatform_v1beta1.services.model_service import pagers -from google.cloud.aiplatform_v1beta1.services.model_service import transports -from google.cloud.aiplatform_v1beta1.types import deployed_model_ref -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import model_evaluation -from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice -from google.cloud.aiplatform_v1beta1.types import model_service -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def test_model_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = ModelServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = ModelServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_model_service_client_client_options(): - # Check the default options have their expected values. - assert ( - ModelServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.ModelServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = ModelServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_model_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.ModelServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_upload_model(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
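# test_model_service_client_client_options and ..._from_dict above show that
# the client_options argument accepts either a ClientOptions instance or a
# plain mapping. A sketch of the two equivalent spellings, assuming this
# generation's import layout; "squid.clam.whelk" is the tests' dummy
# endpoint, and no connection is attempted at construction time.
from google.api_core import client_options
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

creds = credentials.AnonymousCredentials()
opts = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
client_a = ModelServiceClient(credentials=creds, client_options=opts)
client_b = ModelServiceClient(
    credentials=creds, client_options={"api_endpoint": "squid.clam.whelk"}
)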
-    request = model_service.UploadModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.upload_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.upload_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_upload_model_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.upload_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.upload_model(
-            parent="parent_value", model=gca_model.Model(name="name_value"),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-        assert args[0].model == gca_model.Model(name="name_value")
-
-
-def test_upload_model_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.upload_model(
-            model_service.UploadModelRequest(),
-            parent="parent_value",
-            model=gca_model.Model(name="name_value"),
-        )
-
-
-def test_get_model(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.GetModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.get_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model.Model(
-            name="name_value",
-            display_name="display_name_value",
-            description="description_value",
-            metadata_schema_uri="metadata_schema_uri_value",
-            training_pipeline="training_pipeline_value",
-            artifact_uri="artifact_uri_value",
-            supported_deployment_resources_types=[
-                model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
-            ],
-            supported_input_storage_formats=["supported_input_storage_formats_value"],
-            supported_output_storage_formats=["supported_output_storage_formats_value"],
-            etag="etag_value",
-        )
-
-        response = client.get_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, model.Model)
-    assert response.name == "name_value"
-    assert response.display_name == "display_name_value"
-    assert response.description == "description_value"
-    assert response.metadata_schema_uri == "metadata_schema_uri_value"
-    assert response.training_pipeline == "training_pipeline_value"
-    assert response.artifact_uri == "artifact_uri_value"
-    assert response.supported_deployment_resources_types == [
-        model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
-    ]
-    assert response.supported_input_storage_formats == [
-        "supported_input_storage_formats_value"
-    ]
-    assert response.supported_output_storage_formats == [
-        "supported_output_storage_formats_value"
-    ]
-    assert response.etag == "etag_value"
-
-
-def test_get_model_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = model_service.GetModelRequest(name="name/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.get_model), "__call__") as call:
-        call.return_value = model.Model()
-        client.get_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_model_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.get_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model.Model()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.get_model(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_get_model_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_model(
-            model_service.GetModelRequest(), name="name_value",
-        )
-
-
-def test_list_models(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.ListModelsRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_service.ListModelsResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_models(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListModelsPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_models_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = model_service.ListModelsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
-        call.return_value = model_service.ListModelsResponse()
-        client.list_models(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_models_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_service.ListModelsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_models(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_models_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_models(
-            model_service.ListModelsRequest(), parent="parent_value",
-        )
-
-
-def test_list_models_pager():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelsResponse(
-                models=[model.Model(), model.Model(), model.Model(),],
-                next_page_token="abc",
-            ),
-            model_service.ListModelsResponse(models=[], next_page_token="def",),
-            model_service.ListModelsResponse(
-                models=[model.Model(),], next_page_token="ghi",
-            ),
-            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
-            RuntimeError,
-        )
-        results = [i for i in client.list_models(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, model.Model) for i in results)
-
-
-def test_list_models_pages():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelsResponse(
-                models=[model.Model(), model.Model(), model.Model(),],
-                next_page_token="abc",
-            ),
-            model_service.ListModelsResponse(models=[], next_page_token="def",),
-            model_service.ListModelsResponse(
-                models=[model.Model(),], next_page_token="ghi",
-            ),
-            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
-            RuntimeError,
-        )
-        pages = list(client.list_models(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_update_model(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.UpdateModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.update_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_model.Model(
-            name="name_value",
-            display_name="display_name_value",
-            description="description_value",
-            metadata_schema_uri="metadata_schema_uri_value",
-            training_pipeline="training_pipeline_value",
-            artifact_uri="artifact_uri_value",
-            supported_deployment_resources_types=[
-                gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
-            ],
-            supported_input_storage_formats=["supported_input_storage_formats_value"],
-            supported_output_storage_formats=["supported_output_storage_formats_value"],
-            etag="etag_value",
-        )
-
-        response = client.update_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, gca_model.Model)
-    assert response.name == "name_value"
-    assert response.display_name == "display_name_value"
-    assert response.description == "description_value"
-    assert response.metadata_schema_uri == "metadata_schema_uri_value"
-    assert response.training_pipeline == "training_pipeline_value"
-    assert response.artifact_uri == "artifact_uri_value"
-    assert response.supported_deployment_resources_types == [
-        gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
-    ]
-    assert response.supported_input_storage_formats == [
-        "supported_input_storage_formats_value"
-    ]
-    assert response.supported_output_storage_formats == [
-        "supported_output_storage_formats_value"
-    ]
-    assert response.etag == "etag_value"
-
-
-def test_update_model_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.update_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_model.Model()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.update_model(
-            model=gca_model.Model(name="name_value"),
-            update_mask=field_mask.FieldMask(paths=["paths_value"]),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].model == gca_model.Model(name="name_value")
-        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
-
-
-def test_update_model_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.update_model(
-            model_service.UpdateModelRequest(),
-            model=gca_model.Model(name="name_value"),
-            update_mask=field_mask.FieldMask(paths=["paths_value"]),
-        )
-
-
-def test_delete_model(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.DeleteModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.delete_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.delete_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_delete_model_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.delete_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.delete_model(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_delete_model_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.delete_model(
-            model_service.DeleteModelRequest(), name="name_value",
-        )
-
-
-def test_export_model(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.ExportModelRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.export_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.export_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_export_model_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.export_model), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.export_model(
-            name="name_value",
-            output_config=model_service.ExportModelRequest.OutputConfig(
-                export_format_id="export_format_id_value"
-            ),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(
-            export_format_id="export_format_id_value"
-        )
-
-
-def test_export_model_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.export_model(
-            model_service.ExportModelRequest(),
-            name="name_value",
-            output_config=model_service.ExportModelRequest.OutputConfig(
-                export_format_id="export_format_id_value"
-            ),
-        )
-
-
-def test_get_model_evaluation(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.GetModelEvaluationRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_model_evaluation), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_evaluation.ModelEvaluation(
-            name="name_value",
-            metrics_schema_uri="metrics_schema_uri_value",
-            slice_dimensions=["slice_dimensions_value"],
-        )
-
-        response = client.get_model_evaluation(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, model_evaluation.ModelEvaluation)
-    assert response.name == "name_value"
-    assert response.metrics_schema_uri == "metrics_schema_uri_value"
-    assert response.slice_dimensions == ["slice_dimensions_value"]
-
-
-def test_get_model_evaluation_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = model_service.GetModelEvaluationRequest(name="name/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_model_evaluation), "__call__"
-    ) as call:
-        call.return_value = model_evaluation.ModelEvaluation()
-        client.get_model_evaluation(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_model_evaluation_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_model_evaluation), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_evaluation.ModelEvaluation()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.get_model_evaluation(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_get_model_evaluation_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_model_evaluation(
-            model_service.GetModelEvaluationRequest(), name="name_value",
-        )
-
-
-def test_list_model_evaluations(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.ListModelEvaluationsRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluations), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_service.ListModelEvaluationsResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_model_evaluations(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListModelEvaluationsPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_model_evaluations_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = model_service.ListModelEvaluationsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluations), "__call__"
-    ) as call:
-        call.return_value = model_service.ListModelEvaluationsResponse()
-        client.list_model_evaluations(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_model_evaluations_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluations), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_service.ListModelEvaluationsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_model_evaluations(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_model_evaluations_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_model_evaluations(
-            model_service.ListModelEvaluationsRequest(), parent="parent_value",
-        )
-
-
-def test_list_model_evaluations_pager():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluations), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[
-                    model_evaluation.ModelEvaluation(),
-                    model_evaluation.ModelEvaluation(),
-                    model_evaluation.ModelEvaluation(),
-                ],
-                next_page_token="abc",
-            ),
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[], next_page_token="def",
-            ),
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[model_evaluation.ModelEvaluation(),],
-                next_page_token="ghi",
-            ),
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[
-                    model_evaluation.ModelEvaluation(),
-                    model_evaluation.ModelEvaluation(),
-                ],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_model_evaluations(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results)
-
-
-def test_list_model_evaluations_pages():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluations), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[
-                    model_evaluation.ModelEvaluation(),
-                    model_evaluation.ModelEvaluation(),
-                    model_evaluation.ModelEvaluation(),
-                ],
-                next_page_token="abc",
-            ),
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[], next_page_token="def",
-            ),
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[model_evaluation.ModelEvaluation(),],
-                next_page_token="ghi",
-            ),
-            model_service.ListModelEvaluationsResponse(
-                model_evaluations=[
-                    model_evaluation.ModelEvaluation(),
-                    model_evaluation.ModelEvaluation(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_model_evaluations(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_get_model_evaluation_slice(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.GetModelEvaluationSliceRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_model_evaluation_slice), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_evaluation_slice.ModelEvaluationSlice(
-            name="name_value", metrics_schema_uri="metrics_schema_uri_value",
-        )
-
-        response = client.get_model_evaluation_slice(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice)
-    assert response.name == "name_value"
-    assert response.metrics_schema_uri == "metrics_schema_uri_value"
-
-
-def test_get_model_evaluation_slice_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = model_service.GetModelEvaluationSliceRequest(name="name/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_model_evaluation_slice), "__call__"
-    ) as call:
-        call.return_value = model_evaluation_slice.ModelEvaluationSlice()
-        client.get_model_evaluation_slice(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_model_evaluation_slice_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_model_evaluation_slice), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_evaluation_slice.ModelEvaluationSlice()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.get_model_evaluation_slice(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_get_model_evaluation_slice_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_model_evaluation_slice(
-            model_service.GetModelEvaluationSliceRequest(), name="name_value",
-        )
-
-
-def test_list_model_evaluation_slices(transport: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.ListModelEvaluationSlicesRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluation_slices), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_service.ListModelEvaluationSlicesResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_model_evaluation_slices(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListModelEvaluationSlicesPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_model_evaluation_slices_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = model_service.ListModelEvaluationSlicesRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluation_slices), "__call__"
-    ) as call:
-        call.return_value = model_service.ListModelEvaluationSlicesResponse()
-        client.list_model_evaluation_slices(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_model_evaluation_slices_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluation_slices), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model_service.ListModelEvaluationSlicesResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_model_evaluation_slices(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_model_evaluation_slices_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_model_evaluation_slices(
-            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
-        )
-
-
-def test_list_model_evaluation_slices_pager():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluation_slices), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                ],
-                next_page_token="abc",
-            ),
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[], next_page_token="def",
-            ),
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                ],
-                next_page_token="ghi",
-            ),
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                ],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_model_evaluation_slices(request={},)]
-        assert len(results) == 6
-        assert all(
-            isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results
-        )
-
-
-def test_list_model_evaluation_slices_pages():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_model_evaluation_slices), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                ],
-                next_page_token="abc",
-            ),
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[], next_page_token="def",
-            ),
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                ],
-                next_page_token="ghi",
-            ),
-            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                    model_evaluation_slice.ModelEvaluationSlice(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_model_evaluation_slices(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_credentials_transport_error():
-    # It is an error to provide credentials and a transport instance.
-    transport = transports.ModelServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = ModelServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.ModelServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = ModelServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client._transport, transports.ModelServiceGrpcTransport,)
-
-
-def test_model_service_base_transport():
-    # Instantiate the base transport.
-    transport = transports.ModelServiceTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "upload_model",
-        "get_model",
-        "list_models",
-        "update_model",
-        "delete_model",
-        "export_model",
-        "get_model_evaluation",
-        "list_model_evaluations",
-        "get_model_evaluation_slice",
-        "list_model_evaluation_slices",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_model_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        ModelServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_model_service_host_no_port():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_model_service_host_with_port():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_model_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.ModelServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_model_service_grpc_lro_client():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have a api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_model_path():
-    project = "squid"
-    location = "clam"
-    model = "whelk"
-
-    expected = "projects/{project}/locations/{location}/models/{model}".format(
-        project=project, location=location, model=model,
-    )
-    actual = ModelServiceClient.model_path(project, location, model)
-    assert expected == actual
diff --git a/tests/unit/gapic/test_pipeline_service.py b/tests/unit/gapic/test_pipeline_service.py
deleted file mode 100644
index c7c2db4449..0000000000
--- a/tests/unit/gapic/test_pipeline_service.py
+++ /dev/null
@@ -1,675 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from unittest import mock
-
-import grpc
-import math
-import pytest
-
-from google import auth
-from google.api_core import client_options
-from google.api_core import future
-from google.api_core import operations_v1
-from google.auth import credentials
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
-    PipelineServiceClient,
-)
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports
-from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
-from google.cloud.aiplatform_v1beta1.types import env_var
-from google.cloud.aiplatform_v1beta1.types import explanation
-from google.cloud.aiplatform_v1beta1.types import explanation_metadata
-from google.cloud.aiplatform_v1beta1.types import io
-from google.cloud.aiplatform_v1beta1.types import model
-from google.cloud.aiplatform_v1beta1.types import operation as gca_operation
-from google.cloud.aiplatform_v1beta1.types import pipeline_service
-from google.cloud.aiplatform_v1beta1.types import pipeline_state
-from google.cloud.aiplatform_v1beta1.types import training_pipeline
-from google.cloud.aiplatform_v1beta1.types import (
-    training_pipeline as gca_training_pipeline,
-)
-from google.longrunning import operations_pb2
-from google.oauth2 import service_account
-from google.protobuf import any_pb2 as any  # type: ignore
-from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
-from google.protobuf import struct_pb2 as struct  # type: ignore
-from google.protobuf import timestamp_pb2 as timestamp  # type: ignore
-from google.rpc import status_pb2 as status  # type: ignore
-
-
-def test_pipeline_service_client_from_service_account_file():
-    creds = credentials.AnonymousCredentials()
-    with mock.patch.object(
-        service_account.Credentials, "from_service_account_file"
-    ) as factory:
-        factory.return_value = creds
-        client = PipelineServiceClient.from_service_account_file("dummy/file/path.json")
-        assert client._transport._credentials == creds
-
-        client = PipelineServiceClient.from_service_account_json("dummy/file/path.json")
-        assert client._transport._credentials == creds
-
-        assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_pipeline_service_client_client_options():
-    # Check the default options have their expected values.
-    assert (
-        PipelineServiceClient.DEFAULT_OPTIONS.api_endpoint
-        == "aiplatform.googleapis.com"
-    )
-
-    # Check that options can be customized.
-    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.PipelineServiceClient.get_transport_class"
-    ) as gtc:
-        transport = gtc.return_value = mock.MagicMock()
-        client = PipelineServiceClient(client_options=options)
-        transport.assert_called_once_with(credentials=None, host="squid.clam.whelk")
-
-
-def test_pipeline_service_client_client_options_from_dict():
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.PipelineServiceClient.get_transport_class"
-    ) as gtc:
-        transport = gtc.return_value = mock.MagicMock()
-        client = PipelineServiceClient(
-            client_options={"api_endpoint": "squid.clam.whelk"}
-        )
-        transport.assert_called_once_with(credentials=None, host="squid.clam.whelk")
-
-
-def test_create_training_pipeline(transport: str = "grpc"):
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.CreateTrainingPipelineRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_training_pipeline.TrainingPipeline(
-            name="name_value",
-            display_name="display_name_value",
-            training_task_definition="training_task_definition_value",
-            state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
-        )
-
-        response = client.create_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, gca_training_pipeline.TrainingPipeline)
-    assert response.name == "name_value"
-    assert response.display_name == "display_name_value"
-    assert response.training_task_definition == "training_task_definition_value"
-    assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED
-
-
-def test_create_training_pipeline_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_training_pipeline.TrainingPipeline()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.create_training_pipeline(
-            parent="parent_value",
-            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(
-            name="name_value"
-        )
-
-
-def test_create_training_pipeline_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.create_training_pipeline(
-            pipeline_service.CreateTrainingPipelineRequest(),
-            parent="parent_value",
-            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
-        )
-
-
-def test_get_training_pipeline(transport: str = "grpc"):
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.GetTrainingPipelineRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = training_pipeline.TrainingPipeline(
-            name="name_value",
-            display_name="display_name_value",
-            training_task_definition="training_task_definition_value",
-            state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
-        )
-
-        response = client.get_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, training_pipeline.TrainingPipeline)
-    assert response.name == "name_value"
-    assert response.display_name == "display_name_value"
-    assert response.training_task_definition == "training_task_definition_value"
-    assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED
-
-
-def test_get_training_pipeline_field_headers():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = pipeline_service.GetTrainingPipelineRequest(name="name/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_training_pipeline), "__call__"
-    ) as call:
-        call.return_value = training_pipeline.TrainingPipeline()
-        client.get_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_training_pipeline_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = training_pipeline.TrainingPipeline()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.get_training_pipeline(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].name == "name_value"
-
-
-def test_get_training_pipeline_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_training_pipeline(
-            pipeline_service.GetTrainingPipelineRequest(), name="name_value",
-        )
-
-
-def test_list_training_pipelines(transport: str = "grpc"):
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.ListTrainingPipelinesRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = pipeline_service.ListTrainingPipelinesResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_training_pipelines(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListTrainingPipelinesPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_training_pipelines_field_headers():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = pipeline_service.ListTrainingPipelinesRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        call.return_value = pipeline_service.ListTrainingPipelinesResponse()
-        client.list_training_pipelines(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_training_pipelines_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = pipeline_service.ListTrainingPipelinesResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_training_pipelines(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_training_pipelines_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_training_pipelines(
-            pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
-        )
-
-
-def test_list_training_pipelines_pager():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-                next_page_token="abc",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_training_pipelines(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results)
-
-
-def test_list_training_pipelines_pages():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-                next_page_token="abc",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_training_pipelines(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_delete_training_pipeline(transport: str = "grpc"):
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.DeleteTrainingPipelineRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object( - type(client._transport.delete_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", - ) - - -def test_cancel_training_pipeline(transport: str = "grpc"): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.CancelTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), name="name_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.PipelineServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.PipelineServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = PipelineServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.PipelineServiceGrpcTransport,) - - -def test_pipeline_service_base_transport(): - # Instantiate the base transport. - transport = transports.PipelineServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_training_pipeline", - "get_training_pipeline", - "list_training_pipelines", - "delete_training_pipeline", - "cancel_training_pipeline", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_pipeline_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. 
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        PipelineServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_pipeline_service_host_no_port():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_pipeline_service_host_with_port():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_pipeline_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.PipelineServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_pipeline_service_grpc_lro_client():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_model_path():
-    project = "squid"
-    location = "clam"
-    model = "whelk"
-
-    expected = "projects/{project}/locations/{location}/models/{model}".format(
-        project=project, location=location, model=model,
-    )
-    actual = PipelineServiceClient.model_path(project, location, model)
-    assert expected == actual
-
-
-def test_training_pipeline_path():
-    project = "squid"
-    location = "clam"
-    training_pipeline = "whelk"
-
-    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
-        project=project, location=location, training_pipeline=training_pipeline,
-    )
-    actual = PipelineServiceClient.training_pipeline_path(
-        project, location, training_pipeline
-    )
-    assert expected == actual
diff --git a/tests/unit/gapic/test_prediction_service.py b/tests/unit/gapic/test_prediction_service.py
deleted file mode 100644
index b23ab9b529..0000000000
--- a/tests/unit/gapic/test_prediction_service.py
+++ /dev/null
@@ -1,311 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.prediction_service import ( - PredictionServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.prediction_service import transports -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import prediction_service -from google.oauth2 import service_account -from google.protobuf import struct_pb2 as struct # type: ignore - - -def test_prediction_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = PredictionServiceClient.from_service_account_file( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - client = PredictionServiceClient.from_service_account_json( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_prediction_service_client_client_options(): - # Check the default options have their expected values. - assert ( - PredictionServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = PredictionServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_prediction_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = PredictionServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_predict(transport: str = "grpc"): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = prediction_service.PredictRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse( - deployed_model_id="deployed_model_id_value", - ) - - response = client.predict(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, prediction_service.PredictResponse) - assert response.deployed_model_id == "deployed_model_id_value" - - -def test_predict_flattened(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.predict( - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct.Value(null_value=struct.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct.Value( - # null_value=struct.NullValue.NULL_VALUE - # ) - - -def test_predict_flattened_error(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.predict( - prediction_service.PredictRequest(), - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - ) - - -def test_explain(transport: str = "grpc"): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = prediction_service.ExplainRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.explain), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.ExplainResponse( - deployed_model_id="deployed_model_id_value", - ) - - response = client.explain(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, prediction_service.ExplainResponse) - assert response.deployed_model_id == "deployed_model_id_value" - - -def test_explain_flattened(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.explain), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.ExplainResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = client.explain( - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct.Value(null_value=struct.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct.Value( - # null_value=struct.NullValue.NULL_VALUE - # ) - assert args[0].deployed_model_id == "deployed_model_id_value" - - -def test_explain_flattened_error(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.explain( - prediction_service.ExplainRequest(), - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.PredictionServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.PredictionServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = PredictionServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.PredictionServiceGrpcTransport,) - - -def test_prediction_service_base_transport(): - # Instantiate the base transport. - transport = transports.PredictionServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "predict", - "explain", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - -def test_prediction_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. 
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - PredictionServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_prediction_service_host_no_port(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_prediction_service_host_with_port(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_prediction_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.PredictionServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel diff --git a/tests/unit/gapic/test_specialist_pool_service.py b/tests/unit/gapic/test_specialist_pool_service.py deleted file mode 100644 index f8467edf33..0000000000 --- a/tests/unit/gapic/test_specialist_pool_service.py +++ /dev/null @@ -1,681 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( - SpecialistPoolServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool_service -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore - - -def test_specialist_pool_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = SpecialistPoolServiceClient.from_service_account_file( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - client = SpecialistPoolServiceClient.from_service_account_json( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_specialist_pool_service_client_client_options(): - # Check the default options have their expected values. - assert ( - SpecialistPoolServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.SpecialistPoolServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = SpecialistPoolServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.SpecialistPoolServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = SpecialistPoolServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_specialist_pool(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.CreateSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) - - -def test_create_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_specialist_pool( - specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - -def test_get_specialist_pool(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.GetSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - - response = client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] - - -def test_get_specialist_pool_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.GetSpecialistPoolRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = specialist_pool.SpecialistPool() - client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", - ) - - -def test_list_specialist_pools(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.ListSpecialistPoolsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_specialist_pools(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
-    assert isinstance(response, pagers.ListSpecialistPoolsPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_specialist_pools_field_headers():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = specialist_pool_service.ListSpecialistPoolsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
-        client.list_specialist_pools(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_specialist_pools_flattened():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_specialist_pools(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_specialist_pools_flattened_error():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_specialist_pools(
-            specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value",
-        )
-
-
-def test_list_specialist_pools_pager():
-    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-                next_page_token="abc",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[], next_page_token="def",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[specialist_pool.SpecialistPool(),],
-                next_page_token="ghi",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_specialist_pools(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results)
-
-
-def test_list_specialist_pools_pages():
-    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-                next_page_token="abc",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[], next_page_token="def",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[specialist_pool.SpecialistPool(),],
-                next_page_token="ghi",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_specialist_pools(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_delete_specialist_pool(transport: str = "grpc"):
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = specialist_pool_service.DeleteSpecialistPoolRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_specialist_pool), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", - ) - - -def test_update_specialist_pool(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.UpdateSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.update_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.update_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_update_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.update_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_specialist_pool( - specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. 
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = SpecialistPoolServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = SpecialistPoolServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    assert isinstance(client._transport, transports.SpecialistPoolServiceGrpcTransport,)
-
-
-def test_specialist_pool_service_base_transport():
-    # Instantiate the base transport.
-    transport = transports.SpecialistPoolServiceTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "create_specialist_pool",
-        "get_specialist_pool",
-        "list_specialist_pools",
-        "delete_specialist_pool",
-        "update_specialist_pool",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_specialist_pool_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        SpecialistPoolServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_specialist_pool_service_host_no_port():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_specialist_pool_service_host_with_port():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_specialist_pool_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.SpecialistPoolServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_specialist_pool_service_grpc_lro_client():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
- assert transport.operations_client is transport.operations_client - - -def test_specialist_pool_path(): - project = "squid" - location = "clam" - specialist_pool = "whelk" - - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( - project=project, location=location, specialist_pool=specialist_pool, - ) - actual = SpecialistPoolServiceClient.specialist_pool_path( - project, location, specialist_pool - ) - assert expected == actual From a507577075f9850a62cdbc63e329a0353071e1fb Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Mon, 19 Oct 2020 13:53:05 -0700 Subject: [PATCH 02/12] fix unit tests --- google/cloud/aiplatform_v1beta1/__init__.py | 338 +- .../services/dataset_service/__init__.py | 4 +- .../services/dataset_service/async_client.py | 439 +-- .../services/dataset_service/client.py | 539 ++- .../services/dataset_service/pagers.py | 102 +- .../dataset_service/transports/__init__.py | 10 +- .../dataset_service/transports/base.py | 223 +- .../dataset_service/transports/grpc.py | 231 +- .../transports/grpc_asyncio.py | 243 +- .../services/endpoint_service/__init__.py | 4 +- .../services/endpoint_service/async_client.py | 331 +- .../services/endpoint_service/client.py | 400 +- .../services/endpoint_service/pagers.py | 34 +- .../endpoint_service/transports/__init__.py | 10 +- .../endpoint_service/transports/base.py | 166 +- .../endpoint_service/transports/grpc.py | 182 +- .../transports/grpc_asyncio.py | 196 +- .../services/job_service/__init__.py | 4 +- .../services/job_service/async_client.py | 798 ++-- .../services/job_service/client.py | 941 +++-- .../services/job_service/pagers.py | 146 +- .../job_service/transports/__init__.py | 10 +- .../services/job_service/transports/base.py | 355 +- .../services/job_service/transports/grpc.py | 414 +- .../job_service/transports/grpc_asyncio.py | 432 ++- .../services/migration_service/__init__.py | 4 +- .../migration_service/async_client.py | 151 +- .../services/migration_service/client.py | 282 +- .../services/migration_service/pagers.py | 40 +- .../migration_service/transports/__init__.py | 10 +- .../migration_service/transports/base.py | 78 +- .../migration_service/transports/grpc.py | 115 +- .../transports/grpc_asyncio.py | 119 +- .../services/model_service/__init__.py | 4 +- .../services/model_service/async_client.py | 441 +-- .../services/model_service/client.py | 551 ++- .../services/model_service/pagers.py | 108 +- .../model_service/transports/__init__.py | 10 +- .../services/model_service/transports/base.py | 210 +- .../services/model_service/transports/grpc.py | 231 +- .../model_service/transports/grpc_asyncio.py | 239 +- .../services/pipeline_service/__init__.py | 4 +- .../services/pipeline_service/async_client.py | 253 +- .../services/pipeline_service/client.py | 329 +- .../services/pipeline_service/pagers.py | 40 +- .../pipeline_service/transports/__init__.py | 10 +- .../pipeline_service/transports/base.py | 124 +- .../pipeline_service/transports/grpc.py | 167 +- .../transports/grpc_asyncio.py | 173 +- .../services/prediction_service/__init__.py | 4 +- .../prediction_service/async_client.py | 152 +- .../services/prediction_service/client.py | 202 +- .../prediction_service/transports/__init__.py | 10 +- .../prediction_service/transports/base.py | 89 +- .../prediction_service/transports/grpc.py | 108 +- .../transports/grpc_asyncio.py | 111 +- .../specialist_pool_service/__init__.py | 4 +- .../specialist_pool_service/async_client.py | 264 +- .../specialist_pool_service/client.py | 305 +- 
.../specialist_pool_service/pagers.py | 40 +- .../transports/__init__.py | 14 +- .../transports/base.py | 121 +- .../transports/grpc.py | 164 +- .../transports/grpc_asyncio.py | 170 +- .../aiplatform_v1beta1/types/__init__.py | 546 ++- .../types/accelerator_type.py | 5 +- .../aiplatform_v1beta1/types/annotation.py | 21 +- .../types/annotation_spec.py | 13 +- .../types/batch_prediction_job.py | 99 +- .../types/completion_stats.py | 5 +- .../aiplatform_v1beta1/types/custom_job.py | 70 +- .../aiplatform_v1beta1/types/data_item.py | 17 +- .../types/data_labeling_job.py | 63 +- .../cloud/aiplatform_v1beta1/types/dataset.py | 28 +- .../types/dataset_service.py | 102 +- .../types/deployed_model_ref.py | 5 +- .../aiplatform_v1beta1/types/endpoint.py | 36 +- .../types/endpoint_service.py | 68 +- .../cloud/aiplatform_v1beta1/types/env_var.py | 5 +- .../aiplatform_v1beta1/types/explanation.py | 38 +- .../types/explanation_metadata.py | 26 +- .../types/hyperparameter_tuning_job.py | 41 +- google/cloud/aiplatform_v1beta1/types/io.py | 12 +- .../aiplatform_v1beta1/types/job_service.py | 110 +- .../aiplatform_v1beta1/types/job_state.py | 5 +- .../types/machine_resources.py | 28 +- .../types/manual_batch_tuning_parameters.py | 6 +- .../types/migratable_resource.py | 37 +- .../types/migration_service.py | 70 +- .../cloud/aiplatform_v1beta1/types/model.py | 59 +- .../types/model_evaluation.py | 17 +- .../types/model_evaluation_slice.py | 18 +- .../aiplatform_v1beta1/types/model_service.py | 98 +- .../aiplatform_v1beta1/types/operation.py | 23 +- .../types/pipeline_service.py | 30 +- .../types/pipeline_state.py | 5 +- .../types/prediction_service.py | 34 +- .../types/specialist_pool.py | 5 +- .../types/specialist_pool_service.py | 46 +- .../cloud/aiplatform_v1beta1/types/study.py | 89 +- .../types/training_pipeline.py | 74 +- .../types/user_action_reference.py | 9 +- synth.metadata | 2 +- synth.py | 17 +- .../test_dataset_service.py | 2149 +++++----- .../test_endpoint_service.py | 1523 ++++---- .../aiplatform_v1beta1/test_job_service.py | 3445 ++++++++--------- .../test_migration_service.py | 913 +++-- .../aiplatform_v1beta1/test_model_service.py | 2241 +++++------ .../test_pipeline_service.py | 1215 +++--- .../test_prediction_service.py | 695 ++-- .../test_specialist_pool_service.py | 1082 +++--- 112 files changed, 13565 insertions(+), 13683 deletions(-) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index da76eaf689..7d45ebe371 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -187,173 +187,173 @@ __all__ = ( - 'AcceleratorType', - 'ActiveLearningConfig', - 'Annotation', - 'AnnotationSpec', - 'Attribution', - 'AutomaticResources', - 'BatchDedicatedResources', - 'BatchMigrateResourcesOperationMetadata', - 'BatchMigrateResourcesRequest', - 'BatchMigrateResourcesResponse', - 'BatchPredictionJob', - 'BigQueryDestination', - 'BigQuerySource', - 'CancelBatchPredictionJobRequest', - 'CancelCustomJobRequest', - 'CancelDataLabelingJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CancelTrainingPipelineRequest', - 'CompletionStats', - 'ContainerRegistryDestination', - 'ContainerSpec', - 'CreateBatchPredictionJobRequest', - 'CreateCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'CreateDatasetOperationMetadata', - 'CreateDatasetRequest', - 'CreateEndpointOperationMetadata', - 'CreateEndpointRequest', - 'CreateHyperparameterTuningJobRequest', - 'CreateSpecialistPoolOperationMetadata', - 
'CreateSpecialistPoolRequest', - 'CreateTrainingPipelineRequest', - 'CustomJob', - 'CustomJobSpec', - 'DataItem', - 'DataLabelingJob', - 'Dataset', - 'DedicatedResources', - 'DeleteBatchPredictionJobRequest', - 'DeleteCustomJobRequest', - 'DeleteDataLabelingJobRequest', - 'DeleteDatasetRequest', - 'DeleteEndpointRequest', - 'DeleteHyperparameterTuningJobRequest', - 'DeleteModelRequest', - 'DeleteOperationMetadata', - 'DeleteSpecialistPoolRequest', - 'DeleteTrainingPipelineRequest', - 'DeployModelOperationMetadata', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployedModel', - 'DeployedModelRef', - 'Endpoint', - 'EndpointServiceClient', - 'EnvVar', - 'ExplainRequest', - 'ExplainResponse', - 'Explanation', - 'ExplanationMetadata', - 'ExplanationParameters', - 'ExplanationSpec', - 'ExportDataConfig', - 'ExportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportModelOperationMetadata', - 'ExportModelRequest', - 'ExportModelResponse', - 'FilterSplit', - 'FractionSplit', - 'GcsDestination', - 'GcsSource', - 'GenericOperationMetadata', - 'GetAnnotationSpecRequest', - 'GetBatchPredictionJobRequest', - 'GetCustomJobRequest', - 'GetDataLabelingJobRequest', - 'GetDatasetRequest', - 'GetEndpointRequest', - 'GetHyperparameterTuningJobRequest', - 'GetModelEvaluationRequest', - 'GetModelEvaluationSliceRequest', - 'GetModelRequest', - 'GetSpecialistPoolRequest', - 'GetTrainingPipelineRequest', - 'HyperparameterTuningJob', - 'ImportDataConfig', - 'ImportDataOperationMetadata', - 'ImportDataRequest', - 'ImportDataResponse', - 'InputDataConfig', - 'JobServiceClient', - 'JobState', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'ListModelsRequest', - 'ListModelsResponse', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'MachineSpec', - 'ManualBatchTuningParameters', - 'Measurement', - 'MigratableResource', - 'MigrateResourceRequest', - 'MigrateResourceResponse', - 'MigrationServiceClient', - 'Model', - 'ModelContainerSpec', - 'ModelEvaluation', - 'ModelEvaluationSlice', - 'ModelExplanation', - 'ModelServiceClient', - 'PipelineServiceClient', - 'PipelineState', - 'Port', - 'PredefinedSplit', - 'PredictRequest', - 'PredictResponse', - 'PredictSchemata', - 'PredictionServiceClient', - 'PythonPackageSpec', - 'ResourcesConsumed', - 'SampleConfig', - 'SampledShapleyAttribution', - 'Scheduling', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'SpecialistPool', - 'SpecialistPoolServiceClient', - 'StudySpec', - 'TimestampSplit', - 'TrainingConfig', - 'TrainingPipeline', - 'Trial', - 'UndeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UpdateDatasetRequest', - 'UpdateEndpointRequest', - 'UpdateModelRequest', - 'UpdateSpecialistPoolOperationMetadata', - 'UpdateSpecialistPoolRequest', - 'UploadModelOperationMetadata', - 'UploadModelRequest', - 
'UploadModelResponse', - 'UserActionReference', - 'WorkerPoolSpec', -'DatasetServiceClient', + "AcceleratorType", + "ActiveLearningConfig", + "Annotation", + "AnnotationSpec", + "Attribution", + "AutomaticResources", + "BatchDedicatedResources", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "BatchPredictionJob", + "BigQueryDestination", + "BigQuerySource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CancelTrainingPipelineRequest", + "CompletionStats", + "ContainerRegistryDestination", + "ContainerSpec", + "CreateBatchPredictionJobRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "CreateHyperparameterTuningJobRequest", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "CreateTrainingPipelineRequest", + "CustomJob", + "CustomJobSpec", + "DataItem", + "DataLabelingJob", + "Dataset", + "DedicatedResources", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteDatasetRequest", + "DeleteEndpointRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteModelRequest", + "DeleteOperationMetadata", + "DeleteSpecialistPoolRequest", + "DeleteTrainingPipelineRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "DeployedModel", + "DeployedModelRef", + "Endpoint", + "EndpointServiceClient", + "EnvVar", + "ExplainRequest", + "ExplainResponse", + "Explanation", + "ExplanationMetadata", + "ExplanationParameters", + "ExplanationSpec", + "ExportDataConfig", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "FilterSplit", + "FractionSplit", + "GcsDestination", + "GcsSource", + "GenericOperationMetadata", + "GetAnnotationSpecRequest", + "GetBatchPredictionJobRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetDatasetRequest", + "GetEndpointRequest", + "GetHyperparameterTuningJobRequest", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "GetSpecialistPoolRequest", + "GetTrainingPipelineRequest", + "HyperparameterTuningJob", + "ImportDataConfig", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "InputDataConfig", + "JobServiceClient", + "JobState", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "ListEndpointsRequest", + "ListEndpointsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "MachineSpec", + "ManualBatchTuningParameters", + "Measurement", + "MigratableResource", + "MigrateResourceRequest", + 
"MigrateResourceResponse", + "MigrationServiceClient", + "Model", + "ModelContainerSpec", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelExplanation", + "ModelServiceClient", + "PipelineServiceClient", + "PipelineState", + "Port", + "PredefinedSplit", + "PredictRequest", + "PredictResponse", + "PredictSchemata", + "PredictionServiceClient", + "PythonPackageSpec", + "ResourcesConsumed", + "SampleConfig", + "SampledShapleyAttribution", + "Scheduling", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "SpecialistPool", + "SpecialistPoolServiceClient", + "StudySpec", + "TimestampSplit", + "TrainingConfig", + "TrainingPipeline", + "Trial", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateDatasetRequest", + "UpdateEndpointRequest", + "UpdateModelRequest", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "UserActionReference", + "WorkerPoolSpec", + "DatasetServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py index 9d1f004f6a..597f654cb9 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import DatasetServiceAsyncClient __all__ = ( - 'DatasetServiceClient', - 'DatasetServiceAsyncClient', + "DatasetServiceClient", + "DatasetServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 8b67c83c6a..984683b4ac 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -59,26 +59,42 @@ class DatasetServiceAsyncClient: annotation_path = staticmethod(DatasetServiceClient.annotation_path) parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) - parse_annotation_spec_path = staticmethod(DatasetServiceClient.parse_annotation_spec_path) + parse_annotation_spec_path = staticmethod( + DatasetServiceClient.parse_annotation_spec_path + ) data_item_path = staticmethod(DatasetServiceClient.data_item_path) parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) dataset_path = 
staticmethod(DatasetServiceClient.dataset_path) parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) - common_billing_account_path = staticmethod(DatasetServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(DatasetServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + DatasetServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + DatasetServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(DatasetServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + DatasetServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(DatasetServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(DatasetServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + DatasetServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DatasetServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(DatasetServiceClient.common_project_path) - parse_common_project_path = staticmethod(DatasetServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + DatasetServiceClient.parse_common_project_path + ) common_location_path = staticmethod(DatasetServiceClient.common_location_path) - parse_common_location_path = staticmethod(DatasetServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + DatasetServiceClient.parse_common_location_path + ) from_service_account_file = DatasetServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -92,14 +108,18 @@ def transport(self) -> DatasetServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient)) + get_transport_class = functools.partial( + type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -138,18 +158,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a Dataset. 
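The signature above shows the calling convention used by every method in the regenerated clients: callers pass either a prebuilt request object or the flattened fields (here `parent` and `dataset`), and supplying both raises a ValueError. A minimal sketch of the flattened form; the project, location, and display name are hypothetical placeholders, and the returned long-running operation must itself be awaited:

    from google.cloud.aiplatform_v1beta1.services.dataset_service import (
        DatasetServiceAsyncClient,
    )
    from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset

    async def create_dataset_sketch() -> gca_dataset.Dataset:
        client = DatasetServiceAsyncClient()
        # Build the parent resource name with the path helper re-exported above.
        parent = DatasetServiceAsyncClient.common_location_path(
            "my-project", "us-central1"  # hypothetical values
        )
        lro = await client.create_dataset(
            parent=parent,
            dataset=gca_dataset.Dataset(display_name="my-dataset"),
        )
        return await lro.result()  # resolves once the server finishes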
Args: @@ -188,8 +208,10 @@ async def create_dataset(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent, dataset]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.CreateDatasetRequest(request) @@ -212,18 +234,11 @@ async def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -236,14 +251,15 @@ async def create_dataset(self, # Done; return the response. return response - async def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + async def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -273,8 +289,10 @@ async def get_dataset(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.GetDatasetRequest(request) @@ -295,31 +313,25 @@ async def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + async def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. 
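`update_dataset` differs from the create/get methods in two ways visible in this hunk: its flattened fields are the modified `Dataset` plus a `google.protobuf.FieldMask` naming the fields to overwrite, and the routing header is keyed on `dataset.name` rather than `parent` or `name`. A sketch, assuming `client` and a fetched dataset `ds` already exist:

    from google.protobuf import field_mask_pb2

    async def rename_dataset(client, ds):
        ds.display_name = "renamed-dataset"  # hypothetical new name
        # Only fields listed in the mask are written; all other fields
        # on `ds` are left untouched by the server.
        mask = field_mask_pb2.FieldMask(paths=["display_name"])
        return await client.update_dataset(dataset=ds, update_mask=mask)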
Args: @@ -362,8 +374,10 @@ async def update_dataset(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([dataset, update_mask]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.UpdateDatasetRequest(request) @@ -386,30 +400,26 @@ async def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsAsyncPager: + async def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: r"""Lists Datasets in a Location. Args: @@ -442,8 +452,10 @@ async def list_datasets(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListDatasetsRequest(request) @@ -464,39 +476,30 @@ async def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDatasetsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. 
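The `ListDatasetsAsyncPager` wrapped above is what makes the list call feel like a flat sequence: `async for` yields individual `Dataset` messages and transparently issues follow-up `ListDatasetsRequest`s as page tokens arrive. For example, assuming `client` and `parent` from the earlier sketch:

    async def count_datasets(client, parent: str) -> int:
        pager = await client.list_datasets(parent=parent)
        total = 0
        async for _dataset in pager:  # extra pages are fetched on demand
            total += 1
        return total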
return response - async def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Dataset. Args: @@ -542,8 +545,10 @@ async def delete_dataset(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.DeleteDatasetRequest(request) @@ -564,18 +569,11 @@ async def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -588,15 +586,16 @@ async def delete_dataset(self, # Done; return the response. return response - async def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Imports data into a Dataset. Args: @@ -637,8 +636,10 @@ async def import_data(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name, import_configs]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ImportDataRequest(request) @@ -661,18 +662,11 @@ async def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
- response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -685,15 +679,16 @@ async def import_data(self, # Done; return the response. return response - async def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports data from a Dataset. Args: @@ -733,8 +728,10 @@ async def export_data(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name, export_config]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ExportDataRequest(request) @@ -757,18 +754,11 @@ async def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -781,14 +771,15 @@ async def export_data(self, # Done; return the response. return response - async def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsAsyncPager: + async def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: r"""Lists DataItems in a Dataset. Args: @@ -822,8 +813,10 @@ async def list_data_items(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.ListDataItemsRequest(request) @@ -844,39 +837,30 @@ async def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. 
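Both `import_data` and `export_data` above take a config message alongside the dataset `name` and return long-running operations. A sketch of an import from a GCS source; the bucket, file, and schema URI are hypothetical placeholders, not values from this patch:

    from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset
    from google.cloud.aiplatform_v1beta1.types import io

    async def import_items(client, dataset_name: str):
        config = gca_dataset.ImportDataConfig(
            gcs_source=io.GcsSource(uris=["gs://my-bucket/items.jsonl"]),
            import_schema_uri="gs://my-bucket/import_schema.yaml",
        )
        lro = await client.import_data(
            name=dataset_name, import_configs=[config]
        )
        await lro.result()  # ImportDataResponse once ingestion completes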
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListDataItemsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + async def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -908,8 +892,10 @@ async def get_annotation_spec(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = dataset_service.GetAnnotationSpecRequest(request) @@ -930,30 +916,24 @@ async def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsAsyncPager: + async def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: r"""Lists Annotations belongs to a dataitem Args: @@ -988,8 +968,10 @@ async def list_annotations(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) request = dataset_service.ListAnnotationsRequest(request) @@ -1010,47 +992,30 @@ async def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListAnnotationsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceAsyncClient', -) +__all__ = ("DatasetServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index b60c70f7c9..1e63153291 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -59,13 +59,14 @@ class DatasetServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. 
""" - _transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] - _transport_registry['grpc'] = DatasetServiceGrpcTransport - _transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[DatasetServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. Args: @@ -116,7 +117,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -135,9 +136,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -152,110 +152,149 @@ def transport(self) -> DatasetServiceTransport: return self._transport @staticmethod - def annotation_path(project: str,location: str,dataset: str,data_item: str,annotation: str,) -> str: + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: """Return a fully-qualified annotation string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) @staticmethod - def parse_annotation_path(path: str) -> Dict[str,str]: + def parse_annotation_path(path: str) -> Dict[str, str]: """Parse a annotation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def annotation_spec_path(project: str,location: str,dataset: str,annotation_spec: str,) -> str: + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: """Return a fully-qualified annotation_spec string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) @staticmethod - def parse_annotation_spec_path(path: str) -> Dict[str,str]: + def 
parse_annotation_spec_path(path: str) -> Dict[str, str]: """Parse a annotation_spec path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_item_path(project: str,location: str,dataset: str,data_item: str,) -> str: + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: """Return a fully-qualified data_item string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) @staticmethod - def parse_data_item_path(path: str) -> Dict[str,str]: + def parse_data_item_path(path: str) -> Dict[str, str]: """Parse a data_item path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def 
common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the dataset service client. Args: @@ -299,7 +338,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -327,7 +368,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -339,8 +382,10 @@ def __init__(self, *, if isinstance(transport, DatasetServiceTransport): # transport is a DatasetServiceTransport instance. 
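The constructor above resolves which hostname to dial from two environment variables: `GOOGLE_API_USE_CLIENT_CERTIFICATE` controls whether client SSL credentials are created (`is_mtls`), and `GOOGLE_API_USE_MTLS_ENDPOINT` (`never`, `auto`, or `always`) selects between the plain and mTLS endpoints. The decision table, condensed into a standalone sketch:

    import os

    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
    DEFAULT_MTLS_ENDPOINT = "aiplatform.mtls.googleapis.com"

    def resolve_endpoint(is_mtls: bool) -> str:
        mode = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
        if mode == "never":
            return DEFAULT_ENDPOINT
        if mode == "always":
            return DEFAULT_MTLS_ENDPOINT
        if mode == "auto":
            # Use the mTLS endpoint only when client certs are in play.
            return DEFAULT_MTLS_ENDPOINT if is_mtls else DEFAULT_ENDPOINT
        # (The generated client raises MutualTLSChannelError here.)
        raise ValueError(
            "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. "
            "Accepted values: never, auto, always"
        )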
if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -359,15 +404,16 @@ def __init__(self, *, client_info=client_info, ) - def create_dataset(self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a Dataset. Args: @@ -407,8 +453,10 @@ def create_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, dataset]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.CreateDatasetRequest. @@ -432,18 +480,11 @@ def create_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -456,14 +497,15 @@ def create_dataset(self, # Done; return the response. return response - def get_dataset(self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: + def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: r"""Gets a Dataset. Args: @@ -494,8 +536,10 @@ def get_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetDatasetRequest. @@ -517,31 +561,25 @@ def get_dataset(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def update_dataset(self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: + def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: r"""Updates a Dataset. Args: @@ -585,8 +623,10 @@ def update_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([dataset, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.UpdateDatasetRequest. @@ -610,30 +650,26 @@ def update_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('dataset.name', request.dataset.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_datasets(self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsPager: + def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: r"""Lists Datasets in a Location. Args: @@ -667,8 +703,10 @@ def list_datasets(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDatasetsRequest. @@ -690,39 +728,30 @@ def list_datasets(self, # Certain fields should be provided within the metadata header; # add these here. 
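The `to_grpc_metadata` call that recurs throughout these methods renders its key/value pairs as the `x-goog-request-params` gRPC header, which the backend uses to route the call to the addressed resource. Roughly, for a hypothetical resource name:

    from google.api_core import gapic_v1

    header = gapic_v1.routing_header.to_grpc_metadata(
        (("name", "projects/my-project/locations/us-central1/datasets/123"),)
    )
    # header == ("x-goog-request-params", "name=<url-encoded resource name>")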
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_dataset(self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Dataset. Args: @@ -769,8 +798,10 @@ def delete_dataset(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.DeleteDatasetRequest. @@ -792,18 +823,11 @@ def delete_dataset(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -816,15 +840,16 @@ def delete_dataset(self, # Done; return the response. return response - def import_data(self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Imports data into a Dataset. Args: @@ -866,8 +891,10 @@ def import_data(self, # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name, import_configs]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ImportDataRequest. @@ -892,18 +919,11 @@ def import_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -916,15 +936,16 @@ def import_data(self, # Done; return the response. return response - def export_data(self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports data from a Dataset. Args: @@ -965,8 +986,10 @@ def export_data(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, export_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ExportDataRequest. @@ -990,18 +1013,11 @@ def export_data(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1014,14 +1030,15 @@ def export_data(self, # Done; return the response. return response - def list_data_items(self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsPager: + def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: r"""Lists DataItems in a Dataset. 
Args: @@ -1056,8 +1073,10 @@ def list_data_items(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListDataItemsRequest. @@ -1079,39 +1098,30 @@ def list_data_items(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_annotation_spec(self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: + def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: r"""Gets an AnnotationSpec. Args: @@ -1144,8 +1154,10 @@ def get_annotation_spec(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.GetAnnotationSpecRequest. @@ -1167,30 +1179,24 @@ def get_annotation_spec(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. 
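The synchronous pagers wrapped in this file mirror their async counterparts with plain `__iter__`, so item-level iteration crosses page boundaries without any token bookkeeping by the caller. A sketch against `list_data_items`, assuming `client` and a dataset resource name:

    def collect_data_items(client, dataset_name: str) -> list:
        # Each element is a DataItem; the pager follows next_page_token
        # internally until the listing is exhausted.
        return list(client.list_data_items(parent=dataset_name))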
return response - def list_annotations(self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsPager: + def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsPager: r"""Lists Annotations belongs to a dataitem Args: @@ -1226,8 +1232,10 @@ def list_annotations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a dataset_service.ListAnnotationsRequest. @@ -1249,47 +1257,30 @@ def list_annotations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'DatasetServiceClient', -) +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index af29515c1d..43c3156caf 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -40,12 +40,15 @@ class ListDatasetsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListDatasetsResponse], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. 
Args: @@ -79,7 +82,7 @@ def __iter__(self) -> Iterable[dataset.Dataset]: yield from page.datasets def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDatasetsAsyncPager: @@ -99,12 +102,15 @@ class ListDatasetsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -142,7 +148,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataItemsPager: @@ -162,12 +168,15 @@ class ListDataItemsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., dataset_service.ListDataItemsResponse], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +210,7 @@ def __iter__(self) -> Iterable[data_item.DataItem]: yield from page.data_items def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataItemsAsyncPager: @@ -221,12 +230,15 @@ class ListDataItemsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -264,7 +276,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListAnnotationsPager: @@ -284,12 +296,15 @@ class ListAnnotationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - def __init__(self, - method: Callable[..., dataset_service.ListAnnotationsResponse], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +338,7 @@ def __iter__(self) -> Iterable[annotation.Annotation]: yield from page.annotations def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListAnnotationsAsyncPager: @@ -343,12 +358,15 @@ class ListAnnotationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -386,4 +404,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index fd4e511640..f8496b801c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. 
_transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] -_transport_registry['grpc'] = DatasetServiceGrpcTransport -_transport_registry['grpc_asyncio'] = DatasetServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = DatasetServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport __all__ = ( - 'DatasetServiceTransport', - 'DatasetServiceGrpcTransport', - 'DatasetServiceGrpcAsyncIOTransport', + "DatasetServiceTransport", + "DatasetServiceGrpcTransport", + "DatasetServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 1fa9766314..8cceeb197c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -36,29 +36,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,24 +81,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. 
if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -110,56 +112,35 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, - default_timeout=None, - client_info=client_info, + self.create_dataset, default_timeout=None, client_info=client_info, ), self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, - default_timeout=None, - client_info=client_info, + self.get_dataset, default_timeout=None, client_info=client_info, ), self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, - default_timeout=None, - client_info=client_info, + self.update_dataset, default_timeout=None, client_info=client_info, ), self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, - default_timeout=None, - client_info=client_info, + self.list_datasets, default_timeout=None, client_info=client_info, ), self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, - default_timeout=None, - client_info=client_info, + self.delete_dataset, default_timeout=None, client_info=client_info, ), self.import_data: gapic_v1.method.wrap_method( - self.import_data, - default_timeout=None, - client_info=client_info, + self.import_data, default_timeout=None, client_info=client_info, ), self.export_data: gapic_v1.method.wrap_method( - self.export_data, - default_timeout=None, - client_info=client_info, + self.export_data, default_timeout=None, client_info=client_info, ), self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, - default_timeout=None, - client_info=client_info, + self.list_data_items, default_timeout=None, client_info=client_info, ), self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, - default_timeout=None, - client_info=client_info, + self.get_annotation_spec, default_timeout=None, client_info=client_info, ), self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, - default_timeout=None, - client_info=client_info, + self.list_annotations, default_timeout=None, client_info=client_info, ), - } @property @@ -168,96 +149,106 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_dataset(self) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_dataset( + self, + ) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_dataset(self) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[ - dataset.Dataset, - typing.Awaitable[dataset.Dataset] - ]]: + def get_dataset( + self, + ) 
-> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], + ]: raise NotImplementedError() @property - def update_dataset(self) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[ - gca_dataset.Dataset, - typing.Awaitable[gca_dataset.Dataset] - ]]: + def update_dataset( + self, + ) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], + ]: raise NotImplementedError() @property - def list_datasets(self) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse] - ]]: + def list_datasets( + self, + ) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse], + ], + ]: raise NotImplementedError() @property - def delete_dataset(self) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_dataset( + self, + ) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def import_data(self) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def import_data( + self, + ) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_data(self) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_data( + self, + ) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def list_data_items(self) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse] - ]]: + def list_data_items( + self, + ) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse], + ], + ]: raise NotImplementedError() @property - def get_annotation_spec(self) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec] - ]]: + def get_annotation_spec( + self, + ) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec], + ], + ]: raise NotImplementedError() @property - def list_annotations(self) -> typing.Callable[ - [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse] - ]]: + def list_annotations( + self, + ) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse], + ], + 
]: raise NotImplementedError() -__all__ = ( - 'DatasetServiceTransport', -) +__all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 3914e0d35b..801346ca58 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -46,20 +46,23 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -109,12 +112,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. 
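The transport __init__ hunks above normalize the host and resolve credentials in a fixed order: explicit credentials win, then a credentials file, then Application Default Credentials; a bare hostname gets ":443" appended. A minimal sketch of that resolution order follows; the helper names resolve_host and resolve_credentials are illustrative, not part of the library, while auth.default and auth.load_credentials_from_file are the same google-auth calls the patch uses.

from typing import Optional, Sequence

from google import auth  # type: ignore


def resolve_host(host: str) -> str:
    # Default to port 443 (HTTPS) when the caller passes a bare hostname.
    return host if ":" in host else host + ":443"


def resolve_credentials(
    credentials=None,
    credentials_file: Optional[str] = None,
    scopes: Optional[Sequence[str]] = None,
    quota_project_id: Optional[str] = None,
):
    # Explicit credentials and a credentials file are mutually exclusive.
    if credentials and credentials_file:
        raise ValueError("'credentials_file' and 'credentials' are mutually exclusive")
    if credentials_file is not None:
        credentials, _ = auth.load_credentials_from_file(
            credentials_file, scopes=scopes, quota_project_id=quota_project_id
        )
    elif credentials is None:
        # Fall back to Application Default Credentials (ADC).
        credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id)
    return credentials


assert resolve_host("aiplatform.googleapis.com") == "aiplatform.googleapis.com:443"
assert resolve_host("localhost:8080") == "localhost:8080"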
@@ -139,7 +151,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -164,13 +178,15 @@ def __init__(self, *, ) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optional[str]): The host for the channel to use. @@ -203,7 +219,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -220,18 +236,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_dataset(self) -> Callable[ - [dataset_service.CreateDatasetRequest], - operations.Operation]: + def create_dataset( + self, + ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -246,18 +262,18 @@ def create_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_dataset' not in self._stubs: - self._stubs['create_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_dataset'] + return self._stubs["create_dataset"] @property - def get_dataset(self) -> Callable[ - [dataset_service.GetDatasetRequest], - dataset.Dataset]: + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -272,18 +288,18 @@ def get_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each.
- if 'get_dataset' not in self._stubs: - self._stubs['get_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs['get_dataset'] + return self._stubs["get_dataset"] @property - def update_dataset(self) -> Callable[ - [dataset_service.UpdateDatasetRequest], - gca_dataset.Dataset]: + def update_dataset( + self, + ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -298,18 +314,20 @@ def update_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_dataset' not in self._stubs: - self._stubs['update_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs['update_dataset'] + return self._stubs["update_dataset"] @property - def list_datasets(self) -> Callable[ - [dataset_service.ListDatasetsRequest], - dataset_service.ListDatasetsResponse]: + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse + ]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -324,18 +342,18 @@ def list_datasets(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_datasets' not in self._stubs: - self._stubs['list_datasets'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs['list_datasets'] + return self._stubs["list_datasets"] @property - def delete_dataset(self) -> Callable[ - [dataset_service.DeleteDatasetRequest], - operations.Operation]: + def delete_dataset( + self, + ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -350,18 +368,18 @@ def delete_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'delete_dataset' not in self._stubs: - self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_dataset'] + return self._stubs["delete_dataset"] @property - def import_data(self) -> Callable[ - [dataset_service.ImportDataRequest], - operations.Operation]: + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -376,18 +394,18 @@ def import_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_data' not in self._stubs: - self._stubs['import_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_data'] + return self._stubs["import_data"] @property - def export_data(self) -> Callable[ - [dataset_service.ExportDataRequest], - operations.Operation]: + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -402,18 +420,20 @@ def export_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_data' not in self._stubs: - self._stubs['export_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_data'] + return self._stubs["export_data"] @property - def list_data_items(self) -> Callable[ - [dataset_service.ListDataItemsRequest], - dataset_service.ListDataItemsResponse]: + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse + ]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -428,18 +448,20 @@ def list_data_items(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_data_items' not in self._stubs: - self._stubs['list_data_items'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs['list_data_items'] + return self._stubs["list_data_items"] @property - def get_annotation_spec(self) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - annotation_spec.AnnotationSpec]: + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec + ]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -454,18 +476,21 @@ def get_annotation_spec(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_annotation_spec' not in self._stubs: - self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs['get_annotation_spec'] + return self._stubs["get_annotation_spec"] @property - def list_annotations(self) -> Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse]: + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse, + ]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations that belong to a DataItem. @@ -480,15 +505,13 @@ def list_annotations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each.
- if 'list_annotations' not in self._stubs: - self._stubs['list_annotations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs['list_annotations'] + return self._stubs["list_annotations"] -__all__ = ( - 'DatasetServiceGrpcTransport', -) +__all__ = ("DatasetServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index c8d51ca917..c0067cb997 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import annotation_spec @@ -53,13 +53,15 @@ class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. 
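Every RPC property in the transports above (sync and asyncio alike) uses the same lazy stub-caching idiom: build the unary-unary stub on first access, store it in self._stubs, and return the cached callable on later accesses, so gRPC binds the serializer and deserializer exactly once per RPC. A minimal, generic sketch of that idiom; the _StubCache class and its factory argument are illustrative stand-ins (the factory plays the role of grpc_channel.unary_unary), not names from the library.

from typing import Callable, Dict


class _StubCache:
    def __init__(self, factory: Callable[[str], Callable]):
        self._factory = factory  # stands in for grpc_channel.unary_unary
        self._stubs: Dict[str, Callable] = {}

    def get(self, name: str, path: str) -> Callable:
        # Create the stub on first access only; later lookups hit the cache.
        if name not in self._stubs:
            self._stubs[name] = self._factory(path)
        return self._stubs[name]


# Usage: the factory runs exactly once per RPC name.
created = []
cache = _StubCache(lambda path: created.append(path) or (lambda request: request))
stub = cache.get("get_dataset", "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset")
assert cache.get("get_dataset", "ignored") is stub
assert len(created) == 1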
@@ -88,21 +90,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -153,12 +157,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -183,7 +196,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -225,18 +240,20 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_dataset(self) -> Callable[ - [dataset_service.CreateDatasetRequest], - Awaitable[operations.Operation]]: + def create_dataset( + self, + ) -> Callable[ + [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the create dataset method over gRPC. Creates a Dataset. @@ -251,18 +268,18 @@ def create_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'create_dataset' not in self._stubs: - self._stubs['create_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset', + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", request_serializer=dataset_service.CreateDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_dataset'] + return self._stubs["create_dataset"] @property - def get_dataset(self) -> Callable[ - [dataset_service.GetDatasetRequest], - Awaitable[dataset.Dataset]]: + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: r"""Return a callable for the get dataset method over gRPC. Gets a Dataset. @@ -277,18 +294,20 @@ def get_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_dataset' not in self._stubs: - self._stubs['get_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset', + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", request_serializer=dataset_service.GetDatasetRequest.serialize, response_deserializer=dataset.Dataset.deserialize, ) - return self._stubs['get_dataset'] + return self._stubs["get_dataset"] @property - def update_dataset(self) -> Callable[ - [dataset_service.UpdateDatasetRequest], - Awaitable[gca_dataset.Dataset]]: + def update_dataset( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] + ]: r"""Return a callable for the update dataset method over gRPC. Updates a Dataset. @@ -303,18 +322,21 @@ def update_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_dataset' not in self._stubs: - self._stubs['update_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset', + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", request_serializer=dataset_service.UpdateDatasetRequest.serialize, response_deserializer=gca_dataset.Dataset.deserialize, ) - return self._stubs['update_dataset'] + return self._stubs["update_dataset"] @property - def list_datasets(self) -> Callable[ - [dataset_service.ListDatasetsRequest], - Awaitable[dataset_service.ListDatasetsResponse]]: + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse], + ]: r"""Return a callable for the list datasets method over gRPC. Lists Datasets in a Location. @@ -329,18 +351,20 @@ def list_datasets(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_datasets' not in self._stubs: - self._stubs['list_datasets'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets', + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", request_serializer=dataset_service.ListDatasetsRequest.serialize, response_deserializer=dataset_service.ListDatasetsResponse.deserialize, ) - return self._stubs['list_datasets'] + return self._stubs["list_datasets"] @property - def delete_dataset(self) -> Callable[ - [dataset_service.DeleteDatasetRequest], - Awaitable[operations.Operation]]: + def delete_dataset( + self, + ) -> Callable[ + [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete dataset method over gRPC. Deletes a Dataset. @@ -355,18 +379,18 @@ def delete_dataset(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_dataset' not in self._stubs: - self._stubs['delete_dataset'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset', + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", request_serializer=dataset_service.DeleteDatasetRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_dataset'] + return self._stubs["delete_dataset"] @property - def import_data(self) -> Callable[ - [dataset_service.ImportDataRequest], - Awaitable[operations.Operation]]: + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: r"""Return a callable for the import data method over gRPC. Imports data into a Dataset. @@ -381,18 +405,18 @@ def import_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'import_data' not in self._stubs: - self._stubs['import_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ImportData', + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", request_serializer=dataset_service.ImportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['import_data'] + return self._stubs["import_data"] @property - def export_data(self) -> Callable[ - [dataset_service.ExportDataRequest], - Awaitable[operations.Operation]]: + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: r"""Return a callable for the export data method over gRPC. Exports data from a Dataset. @@ -407,18 +431,21 @@ def export_data(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'export_data' not in self._stubs: - self._stubs['export_data'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ExportData', + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", request_serializer=dataset_service.ExportDataRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_data'] + return self._stubs["export_data"] @property - def list_data_items(self) -> Callable[ - [dataset_service.ListDataItemsRequest], - Awaitable[dataset_service.ListDataItemsResponse]]: + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse], + ]: r"""Return a callable for the list data items method over gRPC. Lists DataItems in a Dataset. @@ -433,18 +460,21 @@ def list_data_items(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_items' not in self._stubs: - self._stubs['list_data_items'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems', + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", request_serializer=dataset_service.ListDataItemsRequest.serialize, response_deserializer=dataset_service.ListDataItemsResponse.deserialize, ) - return self._stubs['list_data_items'] + return self._stubs["list_data_items"] @property - def get_annotation_spec(self) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - Awaitable[annotation_spec.AnnotationSpec]]: + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + Awaitable[annotation_spec.AnnotationSpec], + ]: r"""Return a callable for the get annotation spec method over gRPC. Gets an AnnotationSpec. @@ -459,18 +489,21 @@ def get_annotation_spec(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_annotation_spec' not in self._stubs: - self._stubs['get_annotation_spec'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec', + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, response_deserializer=annotation_spec.AnnotationSpec.deserialize, ) - return self._stubs['get_annotation_spec'] + return self._stubs["get_annotation_spec"] @property - def list_annotations(self) -> Callable[ - [dataset_service.ListAnnotationsRequest], - Awaitable[dataset_service.ListAnnotationsResponse]]: + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + Awaitable[dataset_service.ListAnnotationsResponse], + ]: r"""Return a callable for the list annotations method over gRPC. Lists Annotations that belong to a DataItem. @@ -485,15 +518,13 @@ def list_annotations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each.
- if 'list_annotations' not in self._stubs: - self._stubs['list_annotations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations', + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", request_serializer=dataset_service.ListAnnotationsRequest.serialize, response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, ) - return self._stubs['list_annotations'] + return self._stubs["list_annotations"] -__all__ = ( - 'DatasetServiceGrpcAsyncIOTransport', -) +__all__ = ("DatasetServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py index e4f3dcfbcf..035a5b2388 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import EndpointServiceAsyncClient __all__ = ( - 'EndpointServiceClient', - 'EndpointServiceAsyncClient', + "EndpointServiceClient", + "EndpointServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 972cb90855..3801c42a08 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,20 +57,34 @@ class EndpointServiceAsyncClient: model_path = staticmethod(EndpointServiceClient.model_path) parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) - common_billing_account_path = staticmethod(EndpointServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(EndpointServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + EndpointServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + EndpointServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(EndpointServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + EndpointServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(EndpointServiceClient.common_organization_path) - parse_common_organization_path = 
staticmethod(EndpointServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + EndpointServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + EndpointServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(EndpointServiceClient.common_project_path) - parse_common_project_path = staticmethod(EndpointServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + EndpointServiceClient.parse_common_project_path + ) common_location_path = staticmethod(EndpointServiceClient.common_location_path) - parse_common_location_path = staticmethod(EndpointServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + EndpointServiceClient.parse_common_location_path + ) from_service_account_file = EndpointServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -84,14 +98,18 @@ def transport(self) -> EndpointServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient)) + get_transport_class = functools.partial( + type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, EndpointServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -130,18 +148,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates an Endpoint. Args: @@ -181,8 +199,10 @@ async def create_endpoint(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent, endpoint]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.CreateEndpointRequest(request) @@ -205,18 +225,11 @@ async def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. 
- response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -229,14 +242,15 @@ async def create_endpoint(self, # Done; return the response. return response - async def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + async def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -267,8 +281,10 @@ async def get_endpoint(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.GetEndpointRequest(request) @@ -289,30 +305,24 @@ async def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsAsyncPager: + async def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: r"""Lists Endpoints in a Location. Args: @@ -346,8 +356,10 @@ async def list_endpoints(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.ListEndpointsRequest(request) @@ -368,40 +380,31 @@ async def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. 
- response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListEndpointsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + async def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -438,8 +441,10 @@ async def update_endpoint(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([endpoint, update_mask]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.UpdateEndpointRequest(request) @@ -462,30 +467,26 @@ async def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes an Endpoint. Args: @@ -531,8 +532,10 @@ async def delete_endpoint(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) request = endpoint_service.DeleteEndpointRequest(request) @@ -553,18 +556,11 @@ async def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -577,16 +573,19 @@ async def delete_endpoint(self, # Done; return the response. return response - async def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -651,8 +650,10 @@ async def deploy_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([endpoint, deployed_model, traffic_split]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.DeployModelRequest(request) @@ -677,18 +678,11 @@ async def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -701,16 +695,19 @@ async def deploy_model(self, # Done; return the response. 
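Note: the sanity check repeated in every method enforces mutual exclusion between a full request object and the flattened convenience arguments. A sketch of the two valid call shapes (resource names are placeholders):

from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.types import endpoint_service

NAME = "projects/example-project/locations/us-central1/endpoints/123"


async def demo(client: EndpointServiceAsyncClient):
    # Shape 1: a full request object, useful when setting many fields.
    ep = await client.get_endpoint(
        request=endpoint_service.GetEndpointRequest(name=NAME)
    )

    # Shape 2: flattened fields only.
    ep = await client.get_endpoint(name=NAME)

    # Passing both `request` and a flattened field raises the ValueError
    # shown in this diff.
    return ep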
return response - async def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -766,8 +763,10 @@ async def undeploy_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([endpoint, deployed_model_id, traffic_split]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = endpoint_service.UndeployModelRequest(request) @@ -792,18 +791,11 @@ async def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. 
response = operation_async.from_gapic( @@ -817,21 +809,14 @@ async def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceAsyncClient', -) +__all__ = ("EndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index f601c6f145..fbbe7219e4 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -55,13 +55,14 @@ class EndpointServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] - _transport_registry['grpc'] = EndpointServiceGrpcTransport - _transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[EndpointServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[EndpointServiceTransport]] + _transport_registry["grpc"] = EndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. Args: @@ -112,7 +113,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -131,9 +132,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. 
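Note: for reference, the factory documented here is used as below; the key-file path is a placeholder, and `from_service_account_json` is the alias defined a few lines further down:

from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)

# Builds service-account credentials from the file and forwards them as the
# `credentials` keyword argument of the constructor.
client = EndpointServiceClient.from_service_account_file("key.json")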
""" - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -148,88 +148,104 @@ def transport(self) -> EndpointServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return 
"organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, EndpointServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the endpoint service client. Args: @@ -273,7 +289,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -301,7 +319,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -313,8 +333,10 @@ def __init__(self, *, if isinstance(transport, EndpointServiceTransport): # transport is a EndpointServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." 
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -333,15 +355,16 @@ def __init__(self, *, client_info=client_info, ) - def create_endpoint(self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates an Endpoint. Args: @@ -382,8 +405,10 @@ def create_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, endpoint]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.CreateEndpointRequest. @@ -407,18 +432,11 @@ def create_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -431,14 +449,15 @@ def create_endpoint(self, # Done; return the response. return response - def get_endpoint(self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: + def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: r"""Gets an Endpoint. Args: @@ -470,8 +489,10 @@ def get_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.GetEndpointRequest. @@ -493,30 +514,24 @@ def get_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
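Note: the endpoint-selection logic above is driven entirely by two environment variables plus an optional explicit override. A configuration sketch restating what this hunk implements:

import os

from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)

# Accepted values are "never", "auto", and "always"; anything else raises
# MutualTLSChannelError, per the constructor above.
os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "auto"

# Client certificates are only used when this is explicitly "true".
os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "false"

# An explicit api_endpoint in client_options takes precedence over the
# env-driven DEFAULT_ENDPOINT / DEFAULT_MTLS_ENDPOINT choice.
client = EndpointServiceClient(
    client_options=ClientOptions(api_endpoint="aiplatform.googleapis.com")
)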
- response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_endpoints(self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsPager: + def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: r"""Lists Endpoints in a Location. Args: @@ -551,8 +566,10 @@ def list_endpoints(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.ListEndpointsRequest. @@ -574,40 +591,31 @@ def list_endpoints(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_endpoint(self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: + def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: r"""Updates an Endpoint. Args: @@ -645,8 +653,10 @@ def update_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UpdateEndpointRequest. @@ -670,30 +680,26 @@ def update_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint.name', request.endpoint.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_endpoint(self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes an Endpoint. Args: @@ -740,8 +746,10 @@ def delete_endpoint(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeleteEndpointRequest. @@ -763,18 +771,11 @@ def delete_endpoint(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -787,16 +788,19 @@ def delete_endpoint(self, # Done; return the response. return response - def deploy_model(self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[endpoint_service.DeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deploys a Model into this Endpoint, creating a DeployedModel within it. @@ -862,8 +866,10 @@ def deploy_model(self, # gotten any keyword arguments that map to the request. 
has_flattened_params = any([endpoint, deployed_model, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.DeployModelRequest. @@ -881,7 +887,7 @@ def deploy_model(self, request.deployed_model = deployed_model if traffic_split: - request.traffic_split.extend(traffic_split) + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -890,18 +896,11 @@ def deploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -914,16 +913,19 @@ def deploy_model(self, # Done; return the response. return response - def undeploy_model(self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[endpoint_service.UndeployModelRequest.TrafficSplitEntry] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Undeploys a Model from an Endpoint, removing a DeployedModel from it, and freeing all resources it's using. @@ -980,8 +982,10 @@ def undeploy_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a endpoint_service.UndeployModelRequest. @@ -999,7 +1003,7 @@ def undeploy_model(self, request.deployed_model_id = deployed_model_id if traffic_split: - request.traffic_split.extend(traffic_split) + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -1008,18 +1012,11 @@ def undeploy_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. 
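Note: both `deploy_model` and `undeploy_model` now assign `traffic_split` directly because the underlying proto field is a map from DeployedModel id to traffic percentage, not a repeated field that could be `extend`ed. A deployment sketch with placeholder resource names; by convention the key "0" refers to the model being deployed by this very request:

from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint

client = EndpointServiceClient()

lro = client.deploy_model(
    endpoint="projects/example-project/locations/us-central1/endpoints/123",
    deployed_model=gca_endpoint.DeployedModel(
        model="projects/example-project/locations/us-central1/models/456",
    ),
    # Route all traffic to the newly deployed model.
    traffic_split={"0": 100},
)
response = lro.result()  # blocks until the long-running operation completes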
- response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1033,21 +1030,14 @@ def undeploy_model(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'EndpointServiceClient', -) +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 50399b1826..86320c2178 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -38,12 +38,15 @@ class ListEndpointsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., endpoint_service.ListEndpointsResponse], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +80,7 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: yield from page.endpoints def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListEndpointsAsyncPager: @@ -97,12 +100,15 @@ class ListEndpointsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -140,4 +146,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index fea1a635d6..70a87e920e 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. 
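Note: the pagers reformatted above hide page tokens entirely: the sync pager implements `__iter__`, the async one `__aiter__`. A usage sketch with a placeholder parent, assuming the regenerated `services.endpoint_service` package exports both clients:

import asyncio

from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceAsyncClient,
    EndpointServiceClient,
)

PARENT = "projects/example-project/locations/us-central1"

# Sync: ListEndpointsPager fetches follow-up pages lazily during iteration.
for ep in EndpointServiceClient().list_endpoints(parent=PARENT):
    print(ep.display_name)


async def list_async():
    client = EndpointServiceAsyncClient()
    # Async: the awaited call returns a ListEndpointsAsyncPager.
    async for ep in await client.list_endpoints(parent=PARENT):
        print(ep.display_name)


asyncio.run(list_async())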
_transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] -_transport_registry['grpc'] = EndpointServiceGrpcTransport -_transport_registry['grpc_asyncio'] = EndpointServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = EndpointServiceGrpcTransport +_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport __all__ = ( - 'EndpointServiceTransport', - 'EndpointServiceGrpcTransport', - 'EndpointServiceGrpcAsyncIOTransport', + "EndpointServiceTransport", + "EndpointServiceGrpcTransport", + "EndpointServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index cb5e891416..63965464b7 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,29 +35,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -80,24 +80,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. 
if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -109,41 +111,26 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, - default_timeout=None, - client_info=client_info, + self.create_endpoint, default_timeout=None, client_info=client_info, ), self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, - default_timeout=None, - client_info=client_info, + self.get_endpoint, default_timeout=None, client_info=client_info, ), self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, - default_timeout=None, - client_info=client_info, + self.list_endpoints, default_timeout=None, client_info=client_info, ), self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, - default_timeout=None, - client_info=client_info, + self.update_endpoint, default_timeout=None, client_info=client_info, ), self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, - default_timeout=None, - client_info=client_info, + self.delete_endpoint, default_timeout=None, client_info=client_info, ), self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, - default_timeout=None, - client_info=client_info, + self.deploy_model, default_timeout=None, client_info=client_info, ), self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, - default_timeout=None, - client_info=client_info, + self.undeploy_model, default_timeout=None, client_info=client_info, ), - } @property @@ -152,69 +139,70 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_endpoint(self) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_endpoint(self) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[ - endpoint.Endpoint, - typing.Awaitable[endpoint.Endpoint] - ]]: + def get_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def list_endpoints(self) -> typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse] - ]]: + def list_endpoints( + self, + ) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + 
typing.Awaitable[endpoint_service.ListEndpointsResponse], + ], + ]: raise NotImplementedError() @property - def update_endpoint(self) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[ - gca_endpoint.Endpoint, - typing.Awaitable[gca_endpoint.Endpoint] - ]]: + def update_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], + ]: raise NotImplementedError() @property - def delete_endpoint(self) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def deploy_model(self) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def deploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def undeploy_model(self) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def undeploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'EndpointServiceTransport', -) +__all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index fbbf33b2b7..a7ca0d8b13 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -45,20 +45,23 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. 
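Note: the transport classes defined from here on can be selected by registry key or instantiated directly around a prebuilt channel. A sketch, assuming application default credentials are available for the channel-creation path:

from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)
from google.cloud.aiplatform_v1beta1.services.endpoint_service.transports import (
    EndpointServiceGrpcTransport,
)

# Registry key: "grpc" or "grpc_asyncio" (see transports/__init__.py above).
client = EndpointServiceClient(transport="grpc")

# Or hand the transport a ready-made channel; the client then refuses
# separate credentials, per the ValueError in the client constructor diff.
channel = EndpointServiceGrpcTransport.create_channel("aiplatform.googleapis.com:443")
transport = EndpointServiceGrpcTransport(channel=channel)
client = EndpointServiceClient(transport=transport)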
""" + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -108,12 +111,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -138,7 +150,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -163,13 +177,15 @@ def __init__(self, *, ) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optionsl[str]): The host for the channel to use. @@ -202,7 +218,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -219,18 +235,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. 
- if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_endpoint(self) -> Callable[ - [endpoint_service.CreateEndpointRequest], - operations.Operation]: + def create_endpoint( + self, + ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]: r"""Return a callable for the create endpoint method over gRPC. Creates an Endpoint. @@ -245,18 +261,18 @@ def create_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_endpoint' not in self._stubs: - self._stubs['create_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint', + if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", request_serializer=endpoint_service.CreateEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_endpoint'] + return self._stubs["create_endpoint"] @property - def get_endpoint(self) -> Callable[ - [endpoint_service.GetEndpointRequest], - endpoint.Endpoint]: + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: r"""Return a callable for the get endpoint method over gRPC. Gets an Endpoint. @@ -271,18 +287,20 @@ def get_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_endpoint' not in self._stubs: - self._stubs['get_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint', + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", request_serializer=endpoint_service.GetEndpointRequest.serialize, response_deserializer=endpoint.Endpoint.deserialize, ) - return self._stubs['get_endpoint'] + return self._stubs["get_endpoint"] @property - def list_endpoints(self) -> Callable[ - [endpoint_service.ListEndpointsRequest], - endpoint_service.ListEndpointsResponse]: + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse + ]: r"""Return a callable for the list endpoints method over gRPC. Lists Endpoints in a Location. @@ -297,18 +315,18 @@ def list_endpoints(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_endpoints' not in self._stubs: - self._stubs['list_endpoints'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints', + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", request_serializer=endpoint_service.ListEndpointsRequest.serialize, response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, ) - return self._stubs['list_endpoints'] + return self._stubs["list_endpoints"] @property - def update_endpoint(self) -> Callable[ - [endpoint_service.UpdateEndpointRequest], - gca_endpoint.Endpoint]: + def update_endpoint( + self, + ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: r"""Return a callable for the update endpoint method over gRPC. Updates an Endpoint. @@ -323,18 +341,18 @@ def update_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_endpoint' not in self._stubs: - self._stubs['update_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint', + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", request_serializer=endpoint_service.UpdateEndpointRequest.serialize, response_deserializer=gca_endpoint.Endpoint.deserialize, ) - return self._stubs['update_endpoint'] + return self._stubs["update_endpoint"] @property - def delete_endpoint(self) -> Callable[ - [endpoint_service.DeleteEndpointRequest], - operations.Operation]: + def delete_endpoint( + self, + ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: r"""Return a callable for the delete endpoint method over gRPC. Deletes an Endpoint. @@ -349,18 +367,18 @@ def delete_endpoint(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_endpoint' not in self._stubs: - self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint', + if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", request_serializer=endpoint_service.DeleteEndpointRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_endpoint'] + return self._stubs["delete_endpoint"] @property - def deploy_model(self) -> Callable[ - [endpoint_service.DeployModelRequest], - operations.Operation]: + def deploy_model( + self, + ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: r"""Return a callable for the deploy model method over gRPC. Deploys a Model into this Endpoint, creating a @@ -376,18 +394,18 @@ def deploy_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
-        if 'deploy_model' not in self._stubs:
-            self._stubs['deploy_model'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel',
+        if "deploy_model" not in self._stubs:
+            self._stubs["deploy_model"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel",
                 request_serializer=endpoint_service.DeployModelRequest.serialize,
                 response_deserializer=operations.Operation.FromString,
             )
-        return self._stubs['deploy_model']
+        return self._stubs["deploy_model"]

     @property
-    def undeploy_model(self) -> Callable[
-            [endpoint_service.UndeployModelRequest],
-            operations.Operation]:
+    def undeploy_model(
+        self,
+    ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]:
         r"""Return a callable for the undeploy model method over gRPC.

         Undeploys a Model from an Endpoint, removing a

@@ -404,15 +422,13 @@ def undeploy_model(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'undeploy_model' not in self._stubs:
-            self._stubs['undeploy_model'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel',
+        if "undeploy_model" not in self._stubs:
+            self._stubs["undeploy_model"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel",
                 request_serializer=endpoint_service.UndeployModelRequest.serialize,
                 response_deserializer=operations.Operation.FromString,
             )
-        return self._stubs['undeploy_model']
+        return self._stubs["undeploy_model"]


-__all__ = (
-    'EndpointServiceGrpcTransport',
-)
+__all__ = ("EndpointServiceGrpcTransport",)
diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py
index 69d7842201..7d743ebb56 100644
--- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py
@@ -18,14 +18,14 @@
 import warnings
 from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple

-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import grpc_helpers_async # type: ignore
-from google.api_core import operations_v1 # type: ignore
-from google import auth # type: ignore
-from google.auth import credentials # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
 from google.auth.transport.grpc import SslCredentials  # type: ignore

-import grpc # type: ignore
+import grpc  # type: ignore
 from grpc.experimental import aio  # type: ignore

 from google.cloud.aiplatform_v1beta1.types import endpoint
@@ -52,13 +52,15 @@ class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport):
     _stubs: Dict[str, Callable] = {}

     @classmethod
-    def create_channel(cls,
-                       host: str = 'aiplatform.googleapis.com',
-                       credentials: credentials.Credentials = None,
-                       credentials_file: Optional[str] = None,
-                       scopes: Optional[Sequence[str]] = None,
-                       quota_project_id: Optional[str] = None,
-                       **kwargs) -> aio.Channel:
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
         """Create and return a gRPC AsyncIO channel object.
         Args:
             address (Optional[str]): The host for the channel to use.
@@ -87,21 +89,23 @@ def create_channel(cls,
             credentials_file=credentials_file,
             scopes=scopes,
             quota_project_id=quota_project_id,
-            **kwargs
+            **kwargs,
         )

-    def __init__(self, *,
-            host: str = 'aiplatform.googleapis.com',
-            credentials: credentials.Credentials = None,
-            credentials_file: Optional[str] = None,
-            scopes: Optional[Sequence[str]] = None,
-            channel: aio.Channel = None,
-            api_mtls_endpoint: str = None,
-            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
-            ssl_channel_credentials: grpc.ChannelCredentials = None,
-            quota_project_id=None,
-            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-            ) -> None:
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
         """Instantiate the transport.

         Args:
@@ -152,12 +156,21 @@ def __init__(self, *,
             # If a channel was explicitly provided, set it.
             self._grpc_channel = channel
         elif api_mtls_endpoint:
-            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )

-            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )

             if credentials is None:
-                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )

             # Create SSL credentials with client_cert_source or application
             # default SSL credentials.
@@ -182,7 +195,9 @@ def __init__(self, *,
             host = host if ":" in host else host + ":443"

             if credentials is None:
-                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )

             # create a new channel. The provided one is ignored.
             self._grpc_channel = type(self).create_channel(
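(For context, the constructor reformatted above is typically exercised as below; a hypothetical sketch that assumes Application Default Credentials are configured in the environment:)

    from google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.grpc_asyncio import (
        EndpointServiceGrpcAsyncIOTransport,
    )

    # credentials=None makes the transport fall back to auth.default(),
    # which is the branch shown in the diff above.
    transport = EndpointServiceGrpcAsyncIOTransport(
        host="aiplatform.googleapis.com",
        credentials=None,
    )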
@@ -224,18 +239,20 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient:
             client.
         """
         # Sanity check: Only create a new client if we do not already have one.
-        if 'operations_client' not in self.__dict__:
-            self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient(
+        if "operations_client" not in self.__dict__:
+            self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient(
                 self.grpc_channel
             )

         # Return the client from cache.
-        return self.__dict__['operations_client']
+        return self.__dict__["operations_client"]

     @property
-    def create_endpoint(self) -> Callable[
-            [endpoint_service.CreateEndpointRequest],
-            Awaitable[operations.Operation]]:
+    def create_endpoint(
+        self,
+    ) -> Callable[
+        [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation]
+    ]:
         r"""Return a callable for the create endpoint method over gRPC.

         Creates an Endpoint.

@@ -250,18 +267,18 @@ def create_endpoint(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'create_endpoint' not in self._stubs:
-            self._stubs['create_endpoint'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint',
+        if "create_endpoint" not in self._stubs:
+            self._stubs["create_endpoint"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint",
                 request_serializer=endpoint_service.CreateEndpointRequest.serialize,
                 response_deserializer=operations.Operation.FromString,
             )
-        return self._stubs['create_endpoint']
+        return self._stubs["create_endpoint"]

     @property
-    def get_endpoint(self) -> Callable[
-            [endpoint_service.GetEndpointRequest],
-            Awaitable[endpoint.Endpoint]]:
+    def get_endpoint(
+        self,
+    ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]:
         r"""Return a callable for the get endpoint method over gRPC.

         Gets an Endpoint.

@@ -276,18 +293,21 @@ def get_endpoint(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'get_endpoint' not in self._stubs:
-            self._stubs['get_endpoint'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint',
+        if "get_endpoint" not in self._stubs:
+            self._stubs["get_endpoint"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint",
                 request_serializer=endpoint_service.GetEndpointRequest.serialize,
                 response_deserializer=endpoint.Endpoint.deserialize,
             )
-        return self._stubs['get_endpoint']
+        return self._stubs["get_endpoint"]

     @property
-    def list_endpoints(self) -> Callable[
-            [endpoint_service.ListEndpointsRequest],
-            Awaitable[endpoint_service.ListEndpointsResponse]]:
+    def list_endpoints(
+        self,
+    ) -> Callable[
+        [endpoint_service.ListEndpointsRequest],
+        Awaitable[endpoint_service.ListEndpointsResponse],
+    ]:
         r"""Return a callable for the list endpoints method over gRPC.

         Lists Endpoints in a Location.

@@ -302,18 +322,20 @@ def list_endpoints(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'list_endpoints' not in self._stubs:
-            self._stubs['list_endpoints'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints',
+        if "list_endpoints" not in self._stubs:
+            self._stubs["list_endpoints"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints",
                 request_serializer=endpoint_service.ListEndpointsRequest.serialize,
                 response_deserializer=endpoint_service.ListEndpointsResponse.deserialize,
             )
-        return self._stubs['list_endpoints']
+        return self._stubs["list_endpoints"]

     @property
-    def update_endpoint(self) -> Callable[
-            [endpoint_service.UpdateEndpointRequest],
-            Awaitable[gca_endpoint.Endpoint]]:
+    def update_endpoint(
+        self,
+    ) -> Callable[
+        [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint]
+    ]:
         r"""Return a callable for the update endpoint method over gRPC.

         Updates an Endpoint.

@@ -328,18 +350,20 @@ def update_endpoint(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'update_endpoint' not in self._stubs:
-            self._stubs['update_endpoint'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint',
+        if "update_endpoint" not in self._stubs:
+            self._stubs["update_endpoint"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint",
                 request_serializer=endpoint_service.UpdateEndpointRequest.serialize,
                 response_deserializer=gca_endpoint.Endpoint.deserialize,
             )
-        return self._stubs['update_endpoint']
+        return self._stubs["update_endpoint"]

     @property
-    def delete_endpoint(self) -> Callable[
-            [endpoint_service.DeleteEndpointRequest],
-            Awaitable[operations.Operation]]:
+    def delete_endpoint(
+        self,
+    ) -> Callable[
+        [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation]
+    ]:
         r"""Return a callable for the delete endpoint method over gRPC.

         Deletes an Endpoint.

@@ -354,18 +378,20 @@ def delete_endpoint(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'delete_endpoint' not in self._stubs:
-            self._stubs['delete_endpoint'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint',
+        if "delete_endpoint" not in self._stubs:
+            self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint",
                 request_serializer=endpoint_service.DeleteEndpointRequest.serialize,
                 response_deserializer=operations.Operation.FromString,
             )
-        return self._stubs['delete_endpoint']
+        return self._stubs["delete_endpoint"]

     @property
-    def deploy_model(self) -> Callable[
-            [endpoint_service.DeployModelRequest],
-            Awaitable[operations.Operation]]:
+    def deploy_model(
+        self,
+    ) -> Callable[
+        [endpoint_service.DeployModelRequest], Awaitable[operations.Operation]
+    ]:
         r"""Return a callable for the deploy model method over gRPC.

         Deploys a Model into this Endpoint, creating a

@@ -381,18 +407,20 @@ def deploy_model(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'deploy_model' not in self._stubs:
-            self._stubs['deploy_model'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel',
+        if "deploy_model" not in self._stubs:
+            self._stubs["deploy_model"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel",
                 request_serializer=endpoint_service.DeployModelRequest.serialize,
                 response_deserializer=operations.Operation.FromString,
             )
-        return self._stubs['deploy_model']
+        return self._stubs["deploy_model"]

     @property
-    def undeploy_model(self) -> Callable[
-            [endpoint_service.UndeployModelRequest],
-            Awaitable[operations.Operation]]:
+    def undeploy_model(
+        self,
+    ) -> Callable[
+        [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation]
+    ]:
         r"""Return a callable for the undeploy model method over gRPC.

         Undeploys a Model from an Endpoint, removing a

@@ -409,15 +437,13 @@ def undeploy_model(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
-        if 'undeploy_model' not in self._stubs:
-            self._stubs['undeploy_model'] = self.grpc_channel.unary_unary(
-                '/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel',
+        if "undeploy_model" not in self._stubs:
+            self._stubs["undeploy_model"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel",
                 request_serializer=endpoint_service.UndeployModelRequest.serialize,
                 response_deserializer=operations.Operation.FromString,
             )
-        return self._stubs['undeploy_model']
+        return self._stubs["undeploy_model"]


-__all__ = (
-    'EndpointServiceGrpcAsyncIOTransport',
-)
+__all__ = ("EndpointServiceGrpcAsyncIOTransport",)
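(The AsyncIO transport exposes each RPC as an awaitable callable; roughly how the generated client drives it, as a hypothetical sketch with a made-up resource name:)

    import asyncio

    from google.cloud.aiplatform_v1beta1.types import endpoint_service

    async def fetch_endpoint() -> None:
        # Reuses the transport construction sketched earlier.
        transport = EndpointServiceGrpcAsyncIOTransport()
        request = endpoint_service.GetEndpointRequest(
            name="projects/my-project/locations/us-central1/endpoints/123"
        )
        response = await transport.get_endpoint(request)
        print(response.display_name)

    asyncio.run(fetch_endpoint())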
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py
index 037407b714..5f157047f5 100644
--- a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py
@@ -19,6 +19,6 @@
 from .async_client import JobServiceAsyncClient

 __all__ = (
-    'JobServiceClient',
-    'JobServiceAsyncClient',
+    "JobServiceClient",
+    "JobServiceAsyncClient",
 )
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py
index d309df53a5..ca5f400eaa 100644
--- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py
@@ -21,25 +21,31 @@
 from typing import Dict, Sequence, Tuple, Type, Union
 import pkg_resources

-import google.api_core.client_options as ClientOptions # type: ignore
-from google.api_core import exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
-from google.auth import credentials # type: ignore
-from google.oauth2 import service_account # type: ignore
+import google.api_core.client_options as ClientOptions  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
 from google.api_core import operation as ga_operation  # type: ignore
 from google.api_core import operation_async  # type: ignore

 from google.cloud.aiplatform_v1beta1.services.job_service import pagers
 from google.cloud.aiplatform_v1beta1.types import batch_prediction_job
-from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job
+from google.cloud.aiplatform_v1beta1.types import (
+    batch_prediction_job as gca_batch_prediction_job,
+)
 from google.cloud.aiplatform_v1beta1.types import completion_stats
 from google.cloud.aiplatform_v1beta1.types import custom_job
 from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
 from google.cloud.aiplatform_v1beta1.types import data_labeling_job
-from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job
+from google.cloud.aiplatform_v1beta1.types import (
+    data_labeling_job as gca_data_labeling_job,
+)
 from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job
-from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job
+from google.cloud.aiplatform_v1beta1.types import (
+    hyperparameter_tuning_job as gca_hyperparameter_tuning_job,
+)
 from google.cloud.aiplatform_v1beta1.types import job_service
 from google.cloud.aiplatform_v1beta1.types import job_state
 from google.cloud.aiplatform_v1beta1.types import machine_resources
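(The `__init__.py` hunk above only requotes the exports; both clients remain importable exactly as before. Sketch:)

    from google.cloud.aiplatform_v1beta1.services.job_service import (
        JobServiceAsyncClient,
        JobServiceClient,
    )

    sync_client = JobServiceClient()        # uses the "grpc" transport by default
    async_client = JobServiceAsyncClient()  # uses "grpc_asyncio" by default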
@@ -66,32 +72,48 @@ class JobServiceAsyncClient:
     DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT

     batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path)
-    parse_batch_prediction_job_path = staticmethod(JobServiceClient.parse_batch_prediction_job_path)
+    parse_batch_prediction_job_path = staticmethod(
+        JobServiceClient.parse_batch_prediction_job_path
+    )
     custom_job_path = staticmethod(JobServiceClient.custom_job_path)
     parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path)
     data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path)
-    parse_data_labeling_job_path = staticmethod(JobServiceClient.parse_data_labeling_job_path)
+    parse_data_labeling_job_path = staticmethod(
+        JobServiceClient.parse_data_labeling_job_path
+    )
     dataset_path = staticmethod(JobServiceClient.dataset_path)
     parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path)
-    hyperparameter_tuning_job_path = staticmethod(JobServiceClient.hyperparameter_tuning_job_path)
-    parse_hyperparameter_tuning_job_path = staticmethod(JobServiceClient.parse_hyperparameter_tuning_job_path)
+    hyperparameter_tuning_job_path = staticmethod(
+        JobServiceClient.hyperparameter_tuning_job_path
+    )
+    parse_hyperparameter_tuning_job_path = staticmethod(
+        JobServiceClient.parse_hyperparameter_tuning_job_path
+    )
     model_path = staticmethod(JobServiceClient.model_path)
     parse_model_path = staticmethod(JobServiceClient.parse_model_path)

-    common_billing_account_path = staticmethod(JobServiceClient.common_billing_account_path)
-    parse_common_billing_account_path = staticmethod(JobServiceClient.parse_common_billing_account_path)
+    common_billing_account_path = staticmethod(
+        JobServiceClient.common_billing_account_path
+    )
+    parse_common_billing_account_path = staticmethod(
+        JobServiceClient.parse_common_billing_account_path
+    )

     common_folder_path = staticmethod(JobServiceClient.common_folder_path)
     parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path)

     common_organization_path = staticmethod(JobServiceClient.common_organization_path)
-    parse_common_organization_path = staticmethod(JobServiceClient.parse_common_organization_path)
+    parse_common_organization_path = staticmethod(
+        JobServiceClient.parse_common_organization_path
+    )

     common_project_path = staticmethod(JobServiceClient.common_project_path)
     parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path)

     common_location_path = staticmethod(JobServiceClient.common_location_path)
-    parse_common_location_path = staticmethod(JobServiceClient.parse_common_location_path)
+    parse_common_location_path = staticmethod(
+        JobServiceClient.parse_common_location_path
+    )

     from_service_account_file = JobServiceClient.from_service_account_file
     from_service_account_json = from_service_account_file
@@ -105,14 +127,18 @@ def transport(self) -> JobServiceTransport:
         """
         return self._client.transport

-    get_transport_class = functools.partial(type(JobServiceClient).get_transport_class, type(JobServiceClient))
+    get_transport_class = functools.partial(
+        type(JobServiceClient).get_transport_class, type(JobServiceClient)
+    )

-    def __init__(self, *,
-            credentials: credentials.Credentials = None,
-            transport: Union[str, JobServiceTransport] = 'grpc_asyncio',
-            client_options: ClientOptions = None,
-            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-            ) -> None:
+    def __init__(
+        self,
+        *,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, JobServiceTransport] = "grpc_asyncio",
+        client_options: ClientOptions = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
         """Instantiate the job service client.

         Args:
@@ -151,18 +177,18 @@ def __init__(self, *,
             transport=transport,
             client_options=client_options,
             client_info=client_info,
-
         )

-    async def create_custom_job(self,
-            request: job_service.CreateCustomJobRequest = None,
-            *,
-            parent: str = None,
-            custom_job: gca_custom_job.CustomJob = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> gca_custom_job.CustomJob:
+    async def create_custom_job(
+        self,
+        request: job_service.CreateCustomJobRequest = None,
+        *,
+        parent: str = None,
+        custom_job: gca_custom_job.CustomJob = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_custom_job.CustomJob:
         r"""Creates a CustomJob. A created CustomJob right away will
         be attempted to be run.

@@ -205,8 +231,10 @@ async def create_custom_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent, custom_job]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CreateCustomJobRequest(request)

@@ -229,30 +257,24 @@ async def create_custom_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

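(The guard reformatted above enforces that callers pass either a request object or the flattened fields, never both. A hypothetical sketch; `my_custom_job` and the parent string are made up:)

    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient
    from google.cloud.aiplatform_v1beta1.types import job_service

    async def create(client: JobServiceAsyncClient, my_custom_job) -> None:
        # Flattened-field style:
        job = await client.create_custom_job(
            parent="projects/my-project/locations/us-central1",
            custom_job=my_custom_job,
        )
        # Equivalent request-object style:
        request = job_service.CreateCustomJobRequest(
            parent="projects/my-project/locations/us-central1",
            custom_job=my_custom_job,
        )
        job = await client.create_custom_job(request=request)
        # Passing request= together with parent=/custom_job= raises ValueError.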
-    async def get_custom_job(self,
-            request: job_service.GetCustomJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> custom_job.CustomJob:
+    async def get_custom_job(
+        self,
+        request: job_service.GetCustomJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> custom_job.CustomJob:
         r"""Gets a CustomJob.

         Args:
@@ -288,8 +310,10 @@ async def get_custom_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.GetCustomJobRequest(request)

@@ -310,30 +334,24 @@ async def get_custom_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def list_custom_jobs(self,
-            request: job_service.ListCustomJobsRequest = None,
-            *,
-            parent: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> pagers.ListCustomJobsAsyncPager:
+    async def list_custom_jobs(
+        self,
+        request: job_service.ListCustomJobsRequest = None,
+        *,
+        parent: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListCustomJobsAsyncPager:
         r"""Lists CustomJobs in a Location.

         Args:
@@ -367,8 +385,10 @@ async def list_custom_jobs(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.ListCustomJobsRequest(request)

@@ -389,39 +409,30 @@ async def list_custom_jobs(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # This method is paged; wrap the response in a pager, which provides
         # an `__aiter__` convenience method.
         response = pagers.ListCustomJobsAsyncPager(
-            method=rpc,
-            request=request,
-            response=response,
-            metadata=metadata,
+            method=rpc, request=request, response=response, metadata=metadata,
         )

         # Done; return the response.
         return response

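(The pager wrapper above is what gives the async list methods their `__aiter__` support; hypothetical usage sketch with a made-up parent:)

    async def show_jobs(client: JobServiceAsyncClient) -> None:
        pager = await client.list_custom_jobs(
            parent="projects/my-project/locations/us-central1"
        )
        async for job in pager:  # transparently fetches subsequent pages
            print(job.name)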
-    async def delete_custom_job(self,
-            request: job_service.DeleteCustomJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> operation_async.AsyncOperation:
+    async def delete_custom_job(
+        self,
+        request: job_service.DeleteCustomJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
         r"""Deletes a CustomJob.

         Args:
@@ -467,8 +478,10 @@ async def delete_custom_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.DeleteCustomJobRequest(request)

@@ -489,18 +502,11 @@ async def delete_custom_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Wrap the response in an operation future.
         response = operation_async.from_gapic(
@@ -513,14 +519,15 @@ async def delete_custom_job(self,
         # Done; return the response.
         return response

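(delete_* methods return an operation_async.AsyncOperation wrapping the server-side long-running operation; hypothetical sketch:)

    async def delete_job(client: JobServiceAsyncClient, name: str) -> None:
        operation = await client.delete_custom_job(name=name)
        await operation.result()  # waits for the LRO; raises if it failed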
-    async def cancel_custom_job(self,
-            request: job_service.CancelCustomJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> None:
+    async def cancel_custom_job(
+        self,
+        request: job_service.CancelCustomJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
         r"""Cancels a CustomJob. Starts asynchronous cancellation on
         the CustomJob. The server makes a best effort
         to cancel the job, but success is not guaranteed. Clients can use
@@ -556,8 +563,10 @@ async def cancel_custom_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CancelCustomJobRequest(request)

@@ -578,28 +587,24 @@ async def cancel_custom_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
         await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-
-    async def create_data_labeling_job(self,
-            request: job_service.CreateDataLabelingJobRequest = None,
-            *,
-            parent: str = None,
-            data_labeling_job: gca_data_labeling_job.DataLabelingJob = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> gca_data_labeling_job.DataLabelingJob:
+            request, retry=retry, timeout=timeout, metadata=metadata,
+        )
+
+    async def create_data_labeling_job(
+        self,
+        request: job_service.CreateDataLabelingJobRequest = None,
+        *,
+        parent: str = None,
+        data_labeling_job: gca_data_labeling_job.DataLabelingJob = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_data_labeling_job.DataLabelingJob:
         r"""Creates a DataLabelingJob.

         Args:
@@ -636,8 +641,10 @@ async def create_data_labeling_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent, data_labeling_job]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CreateDataLabelingJobRequest(request)

@@ -660,30 +667,24 @@ async def create_data_labeling_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def get_data_labeling_job(self,
-            request: job_service.GetDataLabelingJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> data_labeling_job.DataLabelingJob:
+    async def get_data_labeling_job(
+        self,
+        request: job_service.GetDataLabelingJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> data_labeling_job.DataLabelingJob:
         r"""Gets a DataLabelingJob.

         Args:
@@ -715,8 +716,10 @@ async def get_data_labeling_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.GetDataLabelingJobRequest(request)

@@ -737,30 +740,24 @@ async def get_data_labeling_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def list_data_labeling_jobs(self,
-            request: job_service.ListDataLabelingJobsRequest = None,
-            *,
-            parent: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> pagers.ListDataLabelingJobsAsyncPager:
+    async def list_data_labeling_jobs(
+        self,
+        request: job_service.ListDataLabelingJobsRequest = None,
+        *,
+        parent: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListDataLabelingJobsAsyncPager:
         r"""Lists DataLabelingJobs in a Location.

         Args:
@@ -793,8 +790,10 @@ async def list_data_labeling_jobs(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.ListDataLabelingJobsRequest(request)

@@ -815,39 +814,30 @@ async def list_data_labeling_jobs(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # This method is paged; wrap the response in a pager, which provides
         # an `__aiter__` convenience method.
         response = pagers.ListDataLabelingJobsAsyncPager(
-            method=rpc,
-            request=request,
-            response=response,
-            metadata=metadata,
+            method=rpc, request=request, response=response, metadata=metadata,
         )

         # Done; return the response.
         return response

-    async def delete_data_labeling_job(self,
-            request: job_service.DeleteDataLabelingJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> operation_async.AsyncOperation:
+    async def delete_data_labeling_job(
+        self,
+        request: job_service.DeleteDataLabelingJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
         r"""Deletes a DataLabelingJob.

         Args:
@@ -894,8 +884,10 @@ async def delete_data_labeling_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.DeleteDataLabelingJobRequest(request)

@@ -916,18 +908,11 @@ async def delete_data_labeling_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Wrap the response in an operation future.
         response = operation_async.from_gapic(
@@ -940,14 +925,15 @@ async def delete_data_labeling_job(self,
         # Done; return the response.
         return response

-    async def cancel_data_labeling_job(self,
-            request: job_service.CancelDataLabelingJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> None:
+    async def cancel_data_labeling_job(
+        self,
+        request: job_service.CancelDataLabelingJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
         r"""Cancels a DataLabelingJob.
         Success of cancellation is not guaranteed.

@@ -973,8 +959,10 @@ async def cancel_data_labeling_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CancelDataLabelingJobRequest(request)

@@ -995,28 +983,24 @@ async def cancel_data_labeling_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
         await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-
-    async def create_hyperparameter_tuning_job(self,
-            request: job_service.CreateHyperparameterTuningJobRequest = None,
-            *,
-            parent: str = None,
-            hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob:
+            request, retry=retry, timeout=timeout, metadata=metadata,
+        )
+
+    async def create_hyperparameter_tuning_job(
+        self,
+        request: job_service.CreateHyperparameterTuningJobRequest = None,
+        *,
+        parent: str = None,
+        hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob:
         r"""Creates a HyperparameterTuningJob

         Args:
@@ -1055,8 +1039,10 @@ async def create_hyperparameter_tuning_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent, hyperparameter_tuning_job]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CreateHyperparameterTuningJobRequest(request)

@@ -1079,30 +1065,24 @@ async def create_hyperparameter_tuning_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def get_hyperparameter_tuning_job(self,
-            request: job_service.GetHyperparameterTuningJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> hyperparameter_tuning_job.HyperparameterTuningJob:
+    async def get_hyperparameter_tuning_job(
+        self,
+        request: job_service.GetHyperparameterTuningJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> hyperparameter_tuning_job.HyperparameterTuningJob:
         r"""Gets a HyperparameterTuningJob

         Args:
@@ -1136,8 +1116,10 @@ async def get_hyperparameter_tuning_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.GetHyperparameterTuningJobRequest(request)

@@ -1158,30 +1140,24 @@ async def get_hyperparameter_tuning_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def list_hyperparameter_tuning_jobs(self,
-            request: job_service.ListHyperparameterTuningJobsRequest = None,
-            *,
-            parent: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> pagers.ListHyperparameterTuningJobsAsyncPager:
+    async def list_hyperparameter_tuning_jobs(
+        self,
+        request: job_service.ListHyperparameterTuningJobsRequest = None,
+        *,
+        parent: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListHyperparameterTuningJobsAsyncPager:
         r"""Lists HyperparameterTuningJobs in a Location.

         Args:
@@ -1215,8 +1191,10 @@ async def list_hyperparameter_tuning_jobs(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.ListHyperparameterTuningJobsRequest(request)

@@ -1237,39 +1215,30 @@ async def list_hyperparameter_tuning_jobs(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # This method is paged; wrap the response in a pager, which provides
         # an `__aiter__` convenience method.
         response = pagers.ListHyperparameterTuningJobsAsyncPager(
-            method=rpc,
-            request=request,
-            response=response,
-            metadata=metadata,
+            method=rpc, request=request, response=response, metadata=metadata,
         )

         # Done; return the response.
         return response

-    async def delete_hyperparameter_tuning_job(self,
-            request: job_service.DeleteHyperparameterTuningJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> operation_async.AsyncOperation:
+    async def delete_hyperparameter_tuning_job(
+        self,
+        request: job_service.DeleteHyperparameterTuningJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
         r"""Deletes a HyperparameterTuningJob.

         Args:
@@ -1316,8 +1285,10 @@ async def delete_hyperparameter_tuning_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.DeleteHyperparameterTuningJobRequest(request)

@@ -1338,18 +1309,11 @@ async def delete_hyperparameter_tuning_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Wrap the response in an operation future.
         response = operation_async.from_gapic(
@@ -1362,14 +1326,15 @@ async def delete_hyperparameter_tuning_job(self,
         # Done; return the response.
         return response

-    async def cancel_hyperparameter_tuning_job(self,
-            request: job_service.CancelHyperparameterTuningJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> None:
+    async def cancel_hyperparameter_tuning_job(
+        self,
+        request: job_service.CancelHyperparameterTuningJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
         r"""Cancels a HyperparameterTuningJob.
         Starts asynchronous cancellation on the HyperparameterTuningJob.
         The server makes a best effort to cancel the job, but success is not guaranteed.

@@ -1408,8 +1373,10 @@ async def cancel_hyperparameter_tuning_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CancelHyperparameterTuningJobRequest(request)

@@ -1430,28 +1397,24 @@ async def cancel_hyperparameter_tuning_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
         await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-
-    async def create_batch_prediction_job(self,
-            request: job_service.CreateBatchPredictionJobRequest = None,
-            *,
-            parent: str = None,
-            batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> gca_batch_prediction_job.BatchPredictionJob:
+            request, retry=retry, timeout=timeout, metadata=metadata,
+        )
+
+    async def create_batch_prediction_job(
+        self,
+        request: job_service.CreateBatchPredictionJobRequest = None,
+        *,
+        parent: str = None,
+        batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_batch_prediction_job.BatchPredictionJob:
         r"""Creates a BatchPredictionJob. A BatchPredictionJob
         once created will right away be attempted to start.
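(As the docstrings above note, the cancel_* methods return None and cancellation is best-effort; a caller observes the outcome by re-fetching the job. Hypothetical sketch:)

    async def cancel_and_verify(client: JobServiceAsyncClient, name: str) -> None:
        await client.cancel_hyperparameter_tuning_job(name=name)
        job = await client.get_hyperparameter_tuning_job(name=name)
        print(job.state)  # JOB_STATE_CANCELLED once the cancel takes effect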
@@ -1494,8 +1457,10 @@
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent, batch_prediction_job]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CreateBatchPredictionJobRequest(request)

@@ -1518,30 +1483,24 @@
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def get_batch_prediction_job(self,
-            request: job_service.GetBatchPredictionJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> batch_prediction_job.BatchPredictionJob:
+    async def get_batch_prediction_job(
+        self,
+        request: job_service.GetBatchPredictionJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> batch_prediction_job.BatchPredictionJob:
         r"""Gets a BatchPredictionJob

         Args:
@@ -1578,8 +1537,10 @@ async def get_batch_prediction_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.GetBatchPredictionJobRequest(request)

@@ -1600,30 +1561,24 @@ async def get_batch_prediction_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Done; return the response.
         return response

-    async def list_batch_prediction_jobs(self,
-            request: job_service.ListBatchPredictionJobsRequest = None,
-            *,
-            parent: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> pagers.ListBatchPredictionJobsAsyncPager:
+    async def list_batch_prediction_jobs(
+        self,
+        request: job_service.ListBatchPredictionJobsRequest = None,
+        *,
+        parent: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListBatchPredictionJobsAsyncPager:
         r"""Lists BatchPredictionJobs in a Location.

         Args:
@@ -1657,8 +1612,10 @@ async def list_batch_prediction_jobs(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([parent]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.ListBatchPredictionJobsRequest(request)

@@ -1679,39 +1636,30 @@ async def list_batch_prediction_jobs(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', request.parent),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # This method is paged; wrap the response in a pager, which provides
         # an `__aiter__` convenience method.
         response = pagers.ListBatchPredictionJobsAsyncPager(
-            method=rpc,
-            request=request,
-            response=response,
-            metadata=metadata,
+            method=rpc, request=request, response=response, metadata=metadata,
        )

         # Done; return the response.
         return response

-    async def delete_batch_prediction_job(self,
-            request: job_service.DeleteBatchPredictionJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> operation_async.AsyncOperation:
+    async def delete_batch_prediction_job(
+        self,
+        request: job_service.DeleteBatchPredictionJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
         r"""Deletes a BatchPredictionJob. Can only be called on
         jobs that already finished.

         Args:
@@ -1759,8 +1707,10 @@ async def delete_batch_prediction_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.DeleteBatchPredictionJobRequest(request)

@@ -1781,18 +1731,11 @@ async def delete_batch_prediction_job(self,
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
-        response = await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)

         # Wrap the response in an operation future.
         response = operation_async.from_gapic(
@@ -1805,14 +1748,15 @@ async def delete_batch_prediction_job(self,
         # Done; return the response.
         return response

-    async def cancel_batch_prediction_job(self,
-            request: job_service.CancelBatchPredictionJobRequest = None,
-            *,
-            name: str = None,
-            retry: retries.Retry = gapic_v1.method.DEFAULT,
-            timeout: float = None,
-            metadata: Sequence[Tuple[str, str]] = (),
-            ) -> None:
+    async def cancel_batch_prediction_job(
+        self,
+        request: job_service.CancelBatchPredictionJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
         r"""Cancels a BatchPredictionJob.

         Starts asynchronous cancellation on the BatchPredictionJob. The
@@ -1849,8 +1793,10 @@ async def cancel_batch_prediction_job(self,
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
         if request is not None and any([name]):
-            raise ValueError('If the `request` argument is set, then none of '
-                             'the individual field arguments should be set.')
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = job_service.CancelBatchPredictionJobRequest(request)

@@ -1871,35 +1817,23 @@
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('name', request.name),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
         )

         # Send the request.
         await rpc(
-            request,
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
+            request, retry=retry, timeout=timeout, metadata=metadata,
         )

-
-
-
-
-
 try:
     DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
         gapic_version=pkg_resources.get_distribution(
-            'google-cloud-aiplatform',
+            "google-cloud-aiplatform",
         ).version,
     )
 except pkg_resources.DistributionNotFound:
     DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


-__all__ = (
-    'JobServiceAsyncClient',
-)
+__all__ = ("JobServiceAsyncClient",)
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py
index ca78580819..cf840174c5 100644
--- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py
@@ -23,27 +23,33 @@
 import pkg_resources

 from google.api_core import client_options as client_options_lib  # type: ignore
-from google.api_core import exceptions # type: ignore
-from google.api_core import gapic_v1 # type: ignore
-from google.api_core import retry as retries # type: ignore
-from google.auth import credentials # type: ignore
-from google.auth.transport import mtls # type: ignore
-from google.auth.transport.grpc import SslCredentials # type: ignore
-from google.auth.exceptions import MutualTLSChannelError # type: ignore
-from google.oauth2 import service_account # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
 from google.api_core import operation as ga_operation  # type: ignore
 from google.api_core import operation_async  # type: ignore

 from google.cloud.aiplatform_v1beta1.services.job_service import pagers
 from google.cloud.aiplatform_v1beta1.types import batch_prediction_job
-from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job
+from google.cloud.aiplatform_v1beta1.types import (
+    batch_prediction_job as gca_batch_prediction_job,
+)
 from google.cloud.aiplatform_v1beta1.types import completion_stats
 from google.cloud.aiplatform_v1beta1.types import custom_job
 from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
 from google.cloud.aiplatform_v1beta1.types import data_labeling_job
-from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job
+from google.cloud.aiplatform_v1beta1.types import (
+    data_labeling_job as gca_data_labeling_job,
+)
google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources @@ -68,13 +74,12 @@ class JobServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] - _transport_registry['grpc'] = JobServiceGrpcTransport - _transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[JobServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: """Return an appropriate transport class. Args: @@ -125,7 +130,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -144,9 +149,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -161,132 +165,178 @@ def transport(self) -> JobServiceTransport: return self._transport @staticmethod - def batch_prediction_job_path(project: str,location: str,batch_prediction_job: str,) -> str: + def batch_prediction_job_path( + project: str, location: str, batch_prediction_job: str, + ) -> str: """Return a fully-qualified batch_prediction_job string.""" - return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, + location=location, + batch_prediction_job=batch_prediction_job, + ) @staticmethod - def parse_batch_prediction_job_path(path: str) -> Dict[str,str]: + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: """Parse a batch_prediction_job path into its component segments.""" - m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/batchPredictionJobs/(?P<batch_prediction_job>.+?)$", path) + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/batchPredictionJobs/(?P<batch_prediction_job>.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def custom_job_path(project: str,location: str,custom_job: str,) -> str: + def custom_job_path(project: str, location: str, custom_job: str,) -> str: """Return a fully-qualified custom_job string.""" - return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + return 
"projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) @staticmethod - def parse_custom_job_path(path: str) -> Dict[str,str]: + def parse_custom_job_path(path: str) -> Dict[str, str]: """Parse a custom_job path into its component segments.""" - m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/customJobs/(?P<custom_job>.+?)$", path) + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/customJobs/(?P<custom_job>.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def data_labeling_job_path(project: str,location: str,data_labeling_job: str,) -> str: + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) @staticmethod - def parse_data_labeling_job_path(path: str) -> Dict[str,str]: + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: """Parse a data_labeling_job path into its component segments.""" - m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/dataLabelingJobs/(?P<data_labeling_job>.+?)$", path) + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/dataLabelingJobs/(?P<data_labeling_job>.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$", path) + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def hyperparameter_tuning_job_path(project: str,location: str,hyperparameter_tuning_job: str,) -> str: + def hyperparameter_tuning_job_path( + project: str, location: str, hyperparameter_tuning_job: str, + ) -> str: """Return a fully-qualified hyperparameter_tuning_job string.""" - return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) @staticmethod - def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: """Parse a hyperparameter_tuning_job path into its component segments.""" - m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/hyperparameterTuningJobs/(?P<hyperparameter_tuning_job>.+?)$", path) + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/hyperparameterTuningJobs/(?P<hyperparameter_tuning_job>.+?)$", + path, + ) 
return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$", path) + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P<folder>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P<organization>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P<project>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return 
"projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, JobServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the job service client. Args: @@ -330,7 +380,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -358,7 +410,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -370,8 +424,10 @@ def __init__(self, *, if isinstance(transport, JobServiceTransport): # transport is a JobServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -390,15 +446,16 @@ def __init__(self, *, client_info=client_info, ) - def create_custom_job(self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: + def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: r"""Creates a CustomJob. A created CustomJob right away will be attempted to be run. @@ -442,8 +499,10 @@ def create_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, custom_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set."
+ ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateCustomJobRequest. @@ -467,30 +526,24 @@ def create_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_custom_job(self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: + def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: r"""Gets a CustomJob. Args: @@ -527,8 +580,10 @@ def get_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetCustomJobRequest. @@ -550,30 +605,24 @@ def get_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_custom_jobs(self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsPager: + def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: r"""Lists CustomJobs in a Location. Args: @@ -608,8 +657,10 @@ def list_custom_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListCustomJobsRequest. @@ -631,39 +682,30 @@ def list_custom_jobs(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_custom_job(self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a CustomJob. Args: @@ -710,8 +752,10 @@ def delete_custom_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteCustomJobRequest. @@ -733,18 +777,11 @@ def delete_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -757,14 +794,15 @@ def delete_custom_job(self, # Done; return the response. return response - def cancel_custom_job(self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a CustomJob. Starts asynchronous cancellation on the CustomJob. The server makes a best effort to cancel the job, but success is not guaranteed. Clients can use @@ -801,8 +839,10 @@ def cancel_custom_job(self, # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelCustomJobRequest. @@ -824,28 +864,24 @@ def cancel_custom_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_data_labeling_job(self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: + def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: r"""Creates a DataLabelingJob. Args: @@ -883,8 +919,10 @@ def create_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, data_labeling_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateDataLabelingJobRequest. @@ -908,30 +946,24 @@ def create_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_data_labeling_job(self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: + def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: r"""Gets a DataLabelingJob. Args: @@ -964,8 +996,10 @@ def get_data_labeling_job(self, # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetDataLabelingJobRequest. @@ -987,30 +1021,24 @@ def get_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_data_labeling_jobs(self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsPager: + def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: r"""Lists DataLabelingJobs in a Location. Args: @@ -1044,8 +1072,10 @@ def list_data_labeling_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListDataLabelingJobsRequest. @@ -1067,39 +1097,30 @@ def list_data_labeling_jobs(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. 
return response - def delete_data_labeling_job(self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a DataLabelingJob. Args: @@ -1147,8 +1168,10 @@ def delete_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteDataLabelingJobRequest. @@ -1170,18 +1193,11 @@ def delete_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1194,14 +1210,15 @@ def delete_data_labeling_job(self, # Done; return the response. return response - def cancel_data_labeling_job(self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a DataLabelingJob. Success of cancellation is not guaranteed. @@ -1228,8 +1245,10 @@ def cancel_data_labeling_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelDataLabelingJobRequest. @@ -1251,28 +1270,24 @@ def cancel_data_labeling_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_hyperparameter_tuning_job(self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: r"""Creates a HyperparameterTuningJob Args: @@ -1312,8 +1327,10 @@ def create_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, hyperparameter_tuning_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateHyperparameterTuningJobRequest. @@ -1332,35 +1349,31 @@ def create_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_hyperparameter_tuning_job(self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: r"""Gets a HyperparameterTuningJob Args: @@ -1395,8 +1408,10 @@ def get_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetHyperparameterTuningJobRequest. @@ -1413,35 +1428,31 @@ def get_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_hyperparameter_tuning_jobs(self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsPager: + def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: r"""Lists HyperparameterTuningJobs in a Location. Args: @@ -1476,8 +1487,10 @@ def list_hyperparameter_tuning_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListHyperparameterTuningJobsRequest. @@ -1494,44 +1507,37 @@ def list_hyperparameter_tuning_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_hyperparameter_tuning_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. 
response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_hyperparameter_tuning_job(self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a HyperparameterTuningJob. Args: @@ -1579,8 +1585,10 @@ def delete_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteHyperparameterTuningJobRequest. @@ -1597,23 +1605,18 @@ def delete_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -1626,14 +1629,15 @@ def delete_hyperparameter_tuning_job(self, # Done; return the response. return response - def cancel_hyperparameter_tuning_job(self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a HyperparameterTuningJob. Starts asynchronous cancellation on the HyperparameterTuningJob. The server makes a best effort to cancel the job, but success is not guaranteed. @@ -1673,8 +1677,10 @@ def cancel_hyperparameter_tuning_job(self, # gotten any keyword arguments that map to the request. 
has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelHyperparameterTuningJobRequest. @@ -1691,33 +1697,31 @@ def cancel_hyperparameter_tuning_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_hyperparameter_tuning_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - def create_batch_prediction_job(self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: + def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: r"""Creates a BatchPredictionJob. A BatchPredictionJob once created will right away be attempted to start. @@ -1761,8 +1765,10 @@ def create_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, batch_prediction_job]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CreateBatchPredictionJobRequest. @@ -1781,35 +1787,31 @@ def create_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. 
return response - def get_batch_prediction_job(self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: + def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: r"""Gets a BatchPredictionJob Args: @@ -1847,8 +1849,10 @@ def get_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.GetBatchPredictionJobRequest. @@ -1870,30 +1874,24 @@ def get_batch_prediction_job(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_batch_prediction_jobs(self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsPager: + def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: r"""Lists BatchPredictionJobs in a Location. Args: @@ -1928,8 +1926,10 @@ def list_batch_prediction_jobs(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.ListBatchPredictionJobsRequest. @@ -1946,44 +1946,37 @@ def list_batch_prediction_jobs(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_batch_prediction_jobs] + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_batch_prediction_job(self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a BatchPredictionJob. Can only be called on jobs that already finished. @@ -2032,8 +2025,10 @@ def delete_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.DeleteBatchPredictionJobRequest. @@ -2050,23 +2045,18 @@ def delete_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -2079,14 +2069,15 @@ def delete_batch_prediction_job(self, # Done; return the response. return response - def cancel_batch_prediction_job(self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a BatchPredictionJob. Starts asynchronous cancellation on the BatchPredictionJob. 
The @@ -2124,8 +2115,10 @@ def cancel_batch_prediction_job(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a job_service.CancelBatchPredictionJobRequest. @@ -2142,40 +2135,30 @@ def cancel_batch_prediction_job(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_batch_prediction_job] + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'JobServiceClient', -) +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 17cf187f9e..05e5be73ca 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -41,12 +41,15 @@ class ListCustomJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListCustomJobsResponse], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -80,7 +83,7 @@ def __iter__(self) -> Iterable[custom_job.CustomJob]: yield from page.custom_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListCustomJobsAsyncPager: @@ -100,12 +103,15 @@ class ListCustomJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - def __init__(self, - method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -143,7 +149,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataLabelingJobsPager: @@ -163,12 +169,15 @@ class ListDataLabelingJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., job_service.ListDataLabelingJobsResponse], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -202,7 +211,7 @@ def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: yield from page.data_labeling_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListDataLabelingJobsAsyncPager: @@ -222,12 +231,15 @@ class ListDataLabelingJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -265,7 +277,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsPager: @@ -285,12 +297,15 @@ class ListHyperparameterTuningJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - def __init__(self, - method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -324,7 +339,7 @@ def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob yield from page.hyperparameter_tuning_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListHyperparameterTuningJobsAsyncPager: @@ -344,12 +359,17 @@ class ListHyperparameterTuningJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListHyperparameterTuningJobsResponse]], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -371,14 +391,18 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + async def pages( + self, + ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response - def __aiter__(self) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + def __aiter__( + self, + ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: async def async_generator(): async for page in self.pages: for response in page.hyperparameter_tuning_jobs: @@ -387,7 +411,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListBatchPredictionJobsPager: @@ -407,12 +431,15 @@ class ListBatchPredictionJobsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - def __init__(self, - method: Callable[..., job_service.ListBatchPredictionJobsResponse], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -446,7 +473,7 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: yield from page.batch_prediction_jobs def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListBatchPredictionJobsAsyncPager: @@ -466,12 +493,15 @@ class ListBatchPredictionJobsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -509,4 +539,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index f46fff0524..ca4d929cb5 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. 
_transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] -_transport_registry['grpc'] = JobServiceGrpcTransport -_transport_registry['grpc_asyncio'] = JobServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = JobServiceGrpcTransport +_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport __all__ = ( - 'JobServiceTransport', - 'JobServiceGrpcTransport', - 'JobServiceGrpcAsyncIOTransport', + "JobServiceTransport", + "JobServiceGrpcTransport", + "JobServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 5bc1354ad9..04c05890bc 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -21,19 +21,25 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -42,29 +48,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = 
AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -87,24 +93,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -116,29 +124,19 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, - default_timeout=None, - client_info=client_info, + self.create_custom_job, default_timeout=None, client_info=client_info, ), self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, - default_timeout=None, - client_info=client_info, + self.get_custom_job, default_timeout=None, client_info=client_info, ), self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, - default_timeout=None, - client_info=client_info, + self.list_custom_jobs, default_timeout=None, client_info=client_info, ), self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, - default_timeout=None, - client_info=client_info, + self.delete_custom_job, default_timeout=None, client_info=client_info, ), self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, - default_timeout=None, - client_info=client_info, + self.cancel_custom_job, default_timeout=None, client_info=client_info, ), self.create_data_labeling_job: gapic_v1.method.wrap_method( self.create_data_labeling_job, @@ -215,7 +213,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -224,186 +221,216 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_custom_job(self) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, - typing.Awaitable[gca_custom_job.CustomJob] - ]]: + def create_custom_job( + self, + ) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] + ], + ]: raise NotImplementedError() @property - def get_custom_job(self) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[ - custom_job.CustomJob, - typing.Awaitable[custom_job.CustomJob] - ]]: + def get_custom_job( + self, + ) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], + ]: raise NotImplementedError() @property - def 
list_custom_jobs(self) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse] - ]]: + def list_custom_jobs( + self, + ) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_custom_job(self) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_custom_job( + self, + ) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_custom_job(self) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_custom_job( + self, + ) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_data_labeling_job(self) -> typing.Callable[ - [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob] - ]]: + def create_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def get_data_labeling_job(self) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob] - ]]: + def get_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob], + ], + ]: raise NotImplementedError() @property - def list_data_labeling_jobs(self) -> typing.Callable[ - [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse] - ]]: + def list_data_labeling_jobs( + self, + ) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_data_labeling_job(self) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_data_labeling_job(self) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_hyperparameter_tuning_job(self) -> typing.Callable[ - 
[job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def create_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def get_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob] - ]]: + def get_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: raise NotImplementedError() @property - def list_hyperparameter_tuning_jobs(self) -> typing.Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_hyperparameter_tuning_job(self) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() @property - def create_batch_prediction_job(self) -> typing.Callable[ - [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob] - ]]: + def create_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ], + ]: raise NotImplementedError() @property - def get_batch_prediction_job(self) -> typing.Callable[ - [job_service.GetBatchPredictionJobRequest], - typing.Union[ - batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob] - ]]: + def get_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob], + ], + ]: raise 
NotImplementedError() @property - def list_batch_prediction_jobs(self) -> typing.Callable[ - [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse] - ]]: + def list_batch_prediction_jobs( + self, + ) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse], + ], + ]: raise NotImplementedError() @property - def delete_batch_prediction_job(self) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_batch_prediction_job(self) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'JobServiceTransport', -) +__all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 8523b62d35..246b11b5d6 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -18,23 +18,29 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.longrunning import 
operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -54,20 +60,23 @@ class JobServiceGrpcTransport(JobServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -117,12 +126,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -147,7 +165,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -172,13 +192,15 @@ def __init__(self, *, ) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optionsl[str]): The host for the channel to use. @@ -211,7 +233,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -228,18 +250,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. 
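# Reviewer aside, illustrative only. The property whose body follows
# memoizes the operations client in the instance ``__dict__`` on first
# access; later reads return the cached object. The same idiom, reduced to
# the standard library:

class _Transport:
    @property
    def operations_client(self):
        if "operations_client" not in self.__dict__:
            self.__dict__["operations_client"] = object()  # stand-in client
        return self.__dict__["operations_client"]

_t = _Transport()
assert _t.operations_client is _t.operations_client  # built once, then cached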
""" # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_custom_job(self) -> Callable[ - [job_service.CreateCustomJobRequest], - gca_custom_job.CustomJob]: + def create_custom_job( + self, + ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -255,18 +277,18 @@ def create_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_custom_job' not in self._stubs: - self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs['create_custom_job'] + return self._stubs["create_custom_job"] @property - def get_custom_job(self) -> Callable[ - [job_service.GetCustomJobRequest], - custom_job.CustomJob]: + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -281,18 +303,20 @@ def get_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_custom_job' not in self._stubs: - self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs['get_custom_job'] + return self._stubs["get_custom_job"] @property - def list_custom_jobs(self) -> Callable[ - [job_service.ListCustomJobsRequest], - job_service.ListCustomJobsResponse]: + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse + ]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -307,18 +331,18 @@ def list_custom_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_custom_jobs' not in self._stubs: - self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs['list_custom_jobs'] + return self._stubs["list_custom_jobs"] @property - def delete_custom_job(self) -> Callable[ - [job_service.DeleteCustomJobRequest], - operations.Operation]: + def delete_custom_job( + self, + ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -333,18 +357,18 @@ def delete_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_custom_job' not in self._stubs: - self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_custom_job'] + return self._stubs["delete_custom_job"] @property - def cancel_custom_job(self) -> Callable[ - [job_service.CancelCustomJobRequest], - empty.Empty]: + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -371,18 +395,21 @@ def cancel_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_custom_job' not in self._stubs: - self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_custom_job'] + return self._stubs["cancel_custom_job"] @property - def create_data_labeling_job(self) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob]: + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob, + ]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -397,18 +424,20 @@ def create_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'create_data_labeling_job' not in self._stubs: - self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['create_data_labeling_job'] + return self._stubs["create_data_labeling_job"] @property - def get_data_labeling_job(self) -> Callable[ - [job_service.GetDataLabelingJobRequest], - data_labeling_job.DataLabelingJob]: + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob + ]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -423,18 +452,21 @@ def get_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_data_labeling_job' not in self._stubs: - self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['get_data_labeling_job'] + return self._stubs["get_data_labeling_job"] @property - def list_data_labeling_jobs(self) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse]: + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse, + ]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -449,18 +481,18 @@ def list_data_labeling_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_labeling_jobs' not in self._stubs: - self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs['list_data_labeling_jobs'] + return self._stubs["list_data_labeling_jobs"] @property - def delete_data_labeling_job(self) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], - operations.Operation]: + def delete_data_labeling_job( + self, + ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -475,18 +507,18 @@ def delete_data_labeling_job(self) -> Callable[ # the request. 
# gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_data_labeling_job' not in self._stubs: - self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_data_labeling_job'] + return self._stubs["delete_data_labeling_job"] @property - def cancel_data_labeling_job(self) -> Callable[ - [job_service.CancelDataLabelingJobRequest], - empty.Empty]: + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -502,18 +534,21 @@ def cancel_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_data_labeling_job' not in self._stubs: - self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_data_labeling_job'] + return self._stubs["cancel_data_labeling_job"] @property - def create_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob]: + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + ]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -529,18 +564,23 @@ def create_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'create_hyperparameter_tuning_job' not in self._stubs: - self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['create_hyperparameter_tuning_job'] + return self._stubs["create_hyperparameter_tuning_job"] @property - def get_hyperparameter_tuning_job(self) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob]: + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob, + ]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -555,18 +595,23 @@ def get_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_hyperparameter_tuning_job' not in self._stubs: - self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['get_hyperparameter_tuning_job'] + return self._stubs["get_hyperparameter_tuning_job"] @property - def list_hyperparameter_tuning_jobs(self) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse]: + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse, + ]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -582,18 +627,22 @@ def list_hyperparameter_tuning_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_hyperparameter_tuning_jobs' not in self._stubs: - self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs['list_hyperparameter_tuning_jobs'] + return self._stubs["list_hyperparameter_tuning_jobs"] @property - def delete_hyperparameter_tuning_job(self) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - operations.Operation]: + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation + ]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -609,18 +658,20 @@ def delete_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_hyperparameter_tuning_job' not in self._stubs: - self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_hyperparameter_tuning_job'] + return self._stubs["delete_hyperparameter_tuning_job"] @property - def cancel_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - empty.Empty]: + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -649,18 +700,23 @@ def cancel_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'cancel_hyperparameter_tuning_job' not in self._stubs: - self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_hyperparameter_tuning_job'] + return self._stubs["cancel_hyperparameter_tuning_job"] @property - def create_batch_prediction_job(self) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob]: + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + gca_batch_prediction_job.BatchPredictionJob, + ]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -676,18 +732,21 @@ def create_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_batch_prediction_job' not in self._stubs: - self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['create_batch_prediction_job'] + return self._stubs["create_batch_prediction_job"] @property - def get_batch_prediction_job(self) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - batch_prediction_job.BatchPredictionJob]: + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob, + ]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -702,18 +761,21 @@ def get_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'get_batch_prediction_job' not in self._stubs: - self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['get_batch_prediction_job'] + return self._stubs["get_batch_prediction_job"] @property - def list_batch_prediction_jobs(self) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse]: + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse, + ]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -728,18 +790,18 @@ def list_batch_prediction_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_batch_prediction_jobs' not in self._stubs: - self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs['list_batch_prediction_jobs'] + return self._stubs["list_batch_prediction_jobs"] @property - def delete_batch_prediction_job(self) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], - operations.Operation]: + def delete_batch_prediction_job( + self, + ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -755,18 +817,18 @@ def delete_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_batch_prediction_job' not in self._stubs: - self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_batch_prediction_job'] + return self._stubs["delete_batch_prediction_job"] @property - def cancel_batch_prediction_job(self) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], - empty.Empty]: + def cancel_batch_prediction_job( + self, + ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. 
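# Reviewer aside, illustrative only. Note the two deserializer styles in the
# stubs above: proto-plus messages generated for this API expose
# ``serialize``/``deserialize``, while stock protobuf types such as
# ``Operation`` and ``Empty`` expose ``SerializeToString``/``FromString``.
# A runnable one-liner for the latter (requires only ``protobuf``):

from google.protobuf import empty_pb2

assert empty_pb2.Empty.FromString(b"") == empty_pb2.Empty()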
@@ -792,15 +854,13 @@ def cancel_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_batch_prediction_job' not in self._stubs: - self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', + if "cancel_batch_prediction_job" not in self._stubs: + self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_batch_prediction_job'] + return self._stubs["cancel_batch_prediction_job"] -__all__ = ( - 'JobServiceGrpcTransport', -) +__all__ = ("JobServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index ac8e04e542..428b37f268 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -18,24 +18,30 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import job_service from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -61,13 +67,15 @@ class JobServiceGrpcAsyncIOTransport(JobServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - 
credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. @@ -96,21 +104,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -161,12 +171,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -191,7 +210,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -233,18 +254,20 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. 
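# Reviewer aside, illustrative only. Both transports normalize a bare
# hostname by appending the HTTPS port, as in the hunks above
# (``host = host if ":" in host else host + ":443"``). The same check,
# extracted into a standalone helper:

def _normalize_host(host="aiplatform.googleapis.com"):
    return host if ":" in host else host + ":443"

assert _normalize_host() == "aiplatform.googleapis.com:443"
assert _normalize_host("example.com:8443") == "example.com:8443"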
- return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_custom_job(self) -> Callable[ - [job_service.CreateCustomJobRequest], - Awaitable[gca_custom_job.CustomJob]]: + def create_custom_job( + self, + ) -> Callable[ + [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] + ]: r"""Return a callable for the create custom job method over gRPC. Creates a CustomJob. A created CustomJob right away @@ -260,18 +283,18 @@ def create_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_custom_job' not in self._stubs: - self._stubs['create_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob', + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", request_serializer=job_service.CreateCustomJobRequest.serialize, response_deserializer=gca_custom_job.CustomJob.deserialize, ) - return self._stubs['create_custom_job'] + return self._stubs["create_custom_job"] @property - def get_custom_job(self) -> Callable[ - [job_service.GetCustomJobRequest], - Awaitable[custom_job.CustomJob]]: + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: r"""Return a callable for the get custom job method over gRPC. Gets a CustomJob. @@ -286,18 +309,21 @@ def get_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_custom_job' not in self._stubs: - self._stubs['get_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob', + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", request_serializer=job_service.GetCustomJobRequest.serialize, response_deserializer=custom_job.CustomJob.deserialize, ) - return self._stubs['get_custom_job'] + return self._stubs["get_custom_job"] @property - def list_custom_jobs(self) -> Callable[ - [job_service.ListCustomJobsRequest], - Awaitable[job_service.ListCustomJobsResponse]]: + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse], + ]: r"""Return a callable for the list custom jobs method over gRPC. Lists CustomJobs in a Location. @@ -312,18 +338,20 @@ def list_custom_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
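        # The same memoization idiom repeats for every RPC below: each stub
        # is created lazily on first property access and then reused. A
        # minimal sketch of the pattern, with illustrative names only:
        #
        #     if "echo" not in self._stubs:
        #         self._stubs["echo"] = self.grpc_channel.unary_unary(
        #             "/example.v1.EchoService/Echo",
        #             request_serializer=EchoRequest.serialize,
        #             response_deserializer=EchoResponse.deserialize,
        #         )
        #     return self._stubs["echo"]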
- if 'list_custom_jobs' not in self._stubs: - self._stubs['list_custom_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs', + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", request_serializer=job_service.ListCustomJobsRequest.serialize, response_deserializer=job_service.ListCustomJobsResponse.deserialize, ) - return self._stubs['list_custom_jobs'] + return self._stubs["list_custom_jobs"] @property - def delete_custom_job(self) -> Callable[ - [job_service.DeleteCustomJobRequest], - Awaitable[operations.Operation]]: + def delete_custom_job( + self, + ) -> Callable[ + [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete custom job method over gRPC. Deletes a CustomJob. @@ -338,18 +366,18 @@ def delete_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_custom_job' not in self._stubs: - self._stubs['delete_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob', + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", request_serializer=job_service.DeleteCustomJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_custom_job'] + return self._stubs["delete_custom_job"] @property - def cancel_custom_job(self) -> Callable[ - [job_service.CancelCustomJobRequest], - Awaitable[empty.Empty]]: + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: r"""Return a callable for the cancel custom job method over gRPC. Cancels a CustomJob. Starts asynchronous cancellation on the @@ -376,18 +404,21 @@ def cancel_custom_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_custom_job' not in self._stubs: - self._stubs['cancel_custom_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob', + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", request_serializer=job_service.CancelCustomJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_custom_job'] + return self._stubs["cancel_custom_job"] @property - def create_data_labeling_job(self) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - Awaitable[gca_data_labeling_job.DataLabelingJob]]: + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob], + ]: r"""Return a callable for the create data labeling job method over gRPC. Creates a DataLabelingJob. @@ -402,18 +433,21 @@ def create_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'create_data_labeling_job' not in self._stubs: - self._stubs['create_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob', + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", request_serializer=job_service.CreateDataLabelingJobRequest.serialize, response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['create_data_labeling_job'] + return self._stubs["create_data_labeling_job"] @property - def get_data_labeling_job(self) -> Callable[ - [job_service.GetDataLabelingJobRequest], - Awaitable[data_labeling_job.DataLabelingJob]]: + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob], + ]: r"""Return a callable for the get data labeling job method over gRPC. Gets a DataLabelingJob. @@ -428,18 +462,21 @@ def get_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_data_labeling_job' not in self._stubs: - self._stubs['get_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob', + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", request_serializer=job_service.GetDataLabelingJobRequest.serialize, response_deserializer=data_labeling_job.DataLabelingJob.deserialize, ) - return self._stubs['get_data_labeling_job'] + return self._stubs["get_data_labeling_job"] @property - def list_data_labeling_jobs(self) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - Awaitable[job_service.ListDataLabelingJobsResponse]]: + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse], + ]: r"""Return a callable for the list data labeling jobs method over gRPC. Lists DataLabelingJobs in a Location. @@ -454,18 +491,20 @@ def list_data_labeling_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_data_labeling_jobs' not in self._stubs: - self._stubs['list_data_labeling_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs', + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", request_serializer=job_service.ListDataLabelingJobsRequest.serialize, response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, ) - return self._stubs['list_data_labeling_jobs'] + return self._stubs["list_data_labeling_jobs"] @property - def delete_data_labeling_job(self) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], - Awaitable[operations.Operation]]: + def delete_data_labeling_job( + self, + ) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete data labeling job method over gRPC. Deletes a DataLabelingJob. @@ -480,18 +519,18 @@ def delete_data_labeling_job(self) -> Callable[ # the request. 
# gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_data_labeling_job' not in self._stubs: - self._stubs['delete_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob', + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_data_labeling_job'] + return self._stubs["delete_data_labeling_job"] @property - def cancel_data_labeling_job(self) -> Callable[ - [job_service.CancelDataLabelingJobRequest], - Awaitable[empty.Empty]]: + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: r"""Return a callable for the cancel data labeling job method over gRPC. Cancels a DataLabelingJob. Success of cancellation is @@ -507,18 +546,21 @@ def cancel_data_labeling_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_data_labeling_job' not in self._stubs: - self._stubs['cancel_data_labeling_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob', + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", request_serializer=job_service.CancelDataLabelingJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_data_labeling_job'] + return self._stubs["cancel_data_labeling_job"] @property - def create_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob]]: + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ]: r"""Return a callable for the create hyperparameter tuning job method over gRPC. @@ -534,18 +576,23 @@ def create_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'create_hyperparameter_tuning_job' not in self._stubs: - self._stubs['create_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob', + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['create_hyperparameter_tuning_job'] + return self._stubs["create_hyperparameter_tuning_job"] @property - def get_hyperparameter_tuning_job(self) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob]]: + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ]: r"""Return a callable for the get hyperparameter tuning job method over gRPC. Gets a HyperparameterTuningJob @@ -560,18 +607,23 @@ def get_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_hyperparameter_tuning_job' not in self._stubs: - self._stubs['get_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob', + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, ) - return self._stubs['get_hyperparameter_tuning_job'] + return self._stubs["get_hyperparameter_tuning_job"] @property - def list_hyperparameter_tuning_jobs(self) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - Awaitable[job_service.ListHyperparameterTuningJobsResponse]]: + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ]: r"""Return a callable for the list hyperparameter tuning jobs method over gRPC. @@ -587,18 +639,23 @@ def list_hyperparameter_tuning_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_hyperparameter_tuning_jobs' not in self._stubs: - self._stubs['list_hyperparameter_tuning_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs', + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, ) - return self._stubs['list_hyperparameter_tuning_jobs'] + return self._stubs["list_hyperparameter_tuning_jobs"] @property - def delete_hyperparameter_tuning_job(self) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - Awaitable[operations.Operation]]: + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete hyperparameter tuning job method over gRPC. @@ -614,18 +671,22 @@ def delete_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_hyperparameter_tuning_job' not in self._stubs: - self._stubs['delete_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob', + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_hyperparameter_tuning_job'] + return self._stubs["delete_hyperparameter_tuning_job"] @property - def cancel_hyperparameter_tuning_job(self) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - Awaitable[empty.Empty]]: + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel hyperparameter tuning job method over gRPC. @@ -654,18 +715,23 @@ def cancel_hyperparameter_tuning_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'cancel_hyperparameter_tuning_job' not in self._stubs: - self._stubs['cancel_hyperparameter_tuning_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob', + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_hyperparameter_tuning_job'] + return self._stubs["cancel_hyperparameter_tuning_job"] @property - def create_batch_prediction_job(self) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - Awaitable[gca_batch_prediction_job.BatchPredictionJob]]: + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ]: r"""Return a callable for the create batch prediction job method over gRPC. Creates a BatchPredictionJob. A BatchPredictionJob @@ -681,18 +747,21 @@ def create_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_batch_prediction_job' not in self._stubs: - self._stubs['create_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob', + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['create_batch_prediction_job'] + return self._stubs["create_batch_prediction_job"] @property - def get_batch_prediction_job(self) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - Awaitable[batch_prediction_job.BatchPredictionJob]]: + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob], + ]: r"""Return a callable for the get batch prediction job method over gRPC. Gets a BatchPredictionJob @@ -707,18 +776,21 @@ def get_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'get_batch_prediction_job' not in self._stubs: - self._stubs['get_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob', + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", request_serializer=job_service.GetBatchPredictionJobRequest.serialize, response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, ) - return self._stubs['get_batch_prediction_job'] + return self._stubs["get_batch_prediction_job"] @property - def list_batch_prediction_jobs(self) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - Awaitable[job_service.ListBatchPredictionJobsResponse]]: + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse], + ]: r"""Return a callable for the list batch prediction jobs method over gRPC. Lists BatchPredictionJobs in a Location. @@ -733,18 +805,20 @@ def list_batch_prediction_jobs(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_batch_prediction_jobs' not in self._stubs: - self._stubs['list_batch_prediction_jobs'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs', + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, ) - return self._stubs['list_batch_prediction_jobs'] + return self._stubs["list_batch_prediction_jobs"] @property - def delete_batch_prediction_job(self) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], - Awaitable[operations.Operation]]: + def delete_batch_prediction_job( + self, + ) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] + ]: r"""Return a callable for the delete batch prediction job method over gRPC. Deletes a BatchPredictionJob. Can only be called on @@ -760,18 +834,20 @@ def delete_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'delete_batch_prediction_job' not in self._stubs: - self._stubs['delete_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob', + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_batch_prediction_job'] + return self._stubs["delete_batch_prediction_job"] @property - def cancel_batch_prediction_job(self) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], - Awaitable[empty.Empty]]: + def cancel_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel batch prediction job method over gRPC. Cancels a BatchPredictionJob. @@ -797,15 +873,13 @@ def cancel_batch_prediction_job(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'cancel_batch_prediction_job' not in self._stubs: - self._stubs['cancel_batch_prediction_job'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob', + if "cancel_batch_prediction_job" not in self._stubs: + self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_batch_prediction_job'] + return self._stubs["cancel_batch_prediction_job"] -__all__ = ( - 'JobServiceGrpcAsyncIOTransport', -) +__all__ = ("JobServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py index c533a12b45..1d6216d1f7 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import MigrationServiceAsyncClient __all__ = ( - 'MigrationServiceClient', - 'MigrationServiceAsyncClient', + "MigrationServiceClient", + "MigrationServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index dea3b50632..c9008dc298 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # 
type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -51,7 +51,9 @@ class MigrationServiceAsyncClient: DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path) - parse_annotated_dataset_path = staticmethod(MigrationServiceClient.parse_annotated_dataset_path) + parse_annotated_dataset_path = staticmethod( + MigrationServiceClient.parse_annotated_dataset_path + ) dataset_path = staticmethod(MigrationServiceClient.dataset_path) parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) dataset_path = staticmethod(MigrationServiceClient.dataset_path) @@ -65,20 +67,34 @@ class MigrationServiceAsyncClient: version_path = staticmethod(MigrationServiceClient.version_path) parse_version_path = staticmethod(MigrationServiceClient.parse_version_path) - common_billing_account_path = staticmethod(MigrationServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(MigrationServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + MigrationServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + MigrationServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(MigrationServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(MigrationServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + MigrationServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(MigrationServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(MigrationServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + MigrationServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + MigrationServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(MigrationServiceClient.common_project_path) - parse_common_project_path = staticmethod(MigrationServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + MigrationServiceClient.parse_common_project_path + ) common_location_path = staticmethod(MigrationServiceClient.common_location_path) - parse_common_location_path = staticmethod(MigrationServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + MigrationServiceClient.parse_common_location_path + ) from_service_account_file = MigrationServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -92,14 +108,18 @@ def transport(self) -> MigrationServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)) + get_transport_class = functools.partial( + type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, MigrationServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, 
MigrationServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -138,17 +158,17 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesAsyncPager: + async def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -187,8 +207,10 @@ async def search_migratable_resources(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = migration_service.SearchMigratableResourcesRequest(request) @@ -209,40 +231,33 @@ async def search_migratable_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.SearchMigratableResourcesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). 
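For orientation, a minimal usage sketch of the async client methods above; the
project and location values are illustrative assumptions, not taken from this
patch:

    import asyncio

    from google.cloud.aiplatform_v1beta1.services.migration_service import (
        MigrationServiceAsyncClient,
    )

    async def main():
        client = MigrationServiceAsyncClient()
        # Hypothetical project/location, purely for illustration.
        parent = client.common_location_path("my-project", "us-central1")

        # search_migratable_resources returns an async pager; iterating it
        # transparently fetches subsequent pages.
        pager = await client.search_migratable_resources(parent=parent)
        async for resource in pager:
            print(resource)

        # batch_migrate_resources returns a long-running operation future;
        # awaiting .result() yields a BatchMigrateResourcesResponse.
        # op = await client.batch_migrate_resources(
        #     parent=parent, migrate_resource_requests=[...],
        # )
        # response = await op.result()

    asyncio.run(main())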
@@ -288,8 +303,10 @@ async def batch_migrate_resources(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent, migrate_resource_requests]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = migration_service.BatchMigrateResourcesRequest(request) @@ -312,18 +329,11 @@ async def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -337,21 +347,14 @@ async def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceAsyncClient', -) +__all__ = ("MigrationServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index 2acb4a0ac7..bf1f8e5c6b 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore @@ -50,13 +50,14 @@ class MigrationServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. 
""" - _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] - _transport_registry['grpc'] = MigrationServiceGrpcTransport - _transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[MigrationServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[MigrationServiceTransport]] + _transport_registry["grpc"] = MigrationServiceGrpcTransport + _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]: """Return an appropriate transport class. Args: @@ -110,7 +111,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -129,9 +130,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -146,143 +146,183 @@ def transport(self) -> MigrationServiceTransport: return self._transport @staticmethod - def annotated_dataset_path(project: str,dataset: str,annotated_dataset: str,) -> str: + def annotated_dataset_path( + project: str, dataset: str, annotated_dataset: str, + ) -> str: """Return a fully-qualified annotated_dataset string.""" - return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) @staticmethod - def parse_annotated_dataset_path(path: str) -> Dict[str,str]: + def parse_annotated_dataset_path(path: str) -> Dict[str, str]: """Parse a annotated_dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/datasets/(?P.+?)/annotatedDatasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,location: str,dataset: str,) -> str: + def 
dataset_path(project: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + return "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def dataset_path(project: str,dataset: str,) -> str: + def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) @staticmethod - def parse_dataset_path(path: str) -> Dict[str,str]: + def parse_dataset_path(path: str) -> Dict[str, str]: """Parse a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def version_path(project: str,model: str,version: str,) -> str: + def version_path(project: str, model: str, version: str,) -> str: """Return a fully-qualified version string.""" - return "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + return "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) @staticmethod - def 
parse_version_path(path: str) -> Dict[str,str]: + def parse_version_path(path: str) -> Dict[str, str]: """Parse a version path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/models/(?P.+?)/versions/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, MigrationServiceTransport, None] = None, - 
client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, MigrationServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the migration service client. Args: @@ -326,7 +366,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -354,7 +396,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -366,8 +410,10 @@ def __init__(self, *, if isinstance(transport, MigrationServiceTransport): # transport is a MigrationServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -386,14 +432,15 @@ def __init__(self, *, client_info=client_info, ) - def search_migratable_resources(self, - request: migration_service.SearchMigratableResourcesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.SearchMigratableResourcesPager: + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: r"""Searches all of the resources in automl.googleapis.com, datalabeling.googleapis.com and ml.googleapis.com that can be migrated to AI Platform's @@ -433,8 +480,10 @@ def search_migratable_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.SearchMigratableResourcesRequest. @@ -451,45 +500,40 @@ def search_migratable_resources(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
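        # The wrapped callable carries the defaults registered in the
        # transport's _prep_wrapped_messages(): the per-method default
        # timeout (None for these methods) plus the x-goog-api-client
        # metadata derived from client_info.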
- rpc = self._transport._wrapped_methods[self._transport.search_migratable_resources] + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.SearchMigratableResourcesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def batch_migrate_resources(self, - request: migration_service.BatchMigrateResourcesRequest = None, - *, - parent: str = None, - migrate_resource_requests: Sequence[migration_service.MigrateResourceRequest] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Batch migrates resources from ml.googleapis.com, automl.googleapis.com, and datalabeling.googleapis.com to AI Platform (Unified). @@ -536,8 +580,10 @@ def batch_migrate_resources(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, migrate_resource_requests]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a migration_service.BatchMigrateResourcesRequest. @@ -562,21 +608,14 @@ def batch_migrate_resources(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. 
- response = ga_operation.from_gapic( + response = operation.from_gapic( response, self._transport.operations_client, migration_service.BatchMigrateResourcesResponse, @@ -587,21 +626,14 @@ def batch_migrate_resources(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'MigrationServiceClient', -) +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py index cc52903d15..826325f2e4 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py @@ -38,12 +38,15 @@ class SearchMigratableResourcesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., migration_service.SearchMigratableResourcesResponse], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +80,7 @@ def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: yield from page.migratable_resources def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class SearchMigratableResourcesAsyncPager: @@ -97,12 +100,17 @@ class SearchMigratableResourcesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[migration_service.SearchMigratableResourcesResponse]], - request: migration_service.SearchMigratableResourcesRequest, - response: migration_service.SearchMigratableResourcesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[migration_service.SearchMigratableResourcesResponse] + ], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. 
Args: @@ -124,7 +132,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + async def pages( + self, + ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +150,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py index e42711db2e..af727857e7 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] -_transport_registry['grpc'] = MigrationServiceGrpcTransport -_transport_registry['grpc_asyncio'] = MigrationServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = MigrationServiceGrpcTransport +_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport __all__ = ( - 'MigrationServiceTransport', - 'MigrationServiceGrpcTransport', - 'MigrationServiceGrpcAsyncIOTransport', + "MigrationServiceTransport", + "MigrationServiceGrpcTransport", + "MigrationServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py index e48c2471f6..cbcb288489 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -33,29 +33,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class MigrationServiceTransport(abc.ABC): """Abstract transport class for MigrationService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: 
typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -78,24 +78,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -116,7 +118,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -125,24 +126,25 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def search_migratable_resources(self) -> typing.Callable[ - [migration_service.SearchMigratableResourcesRequest], - typing.Union[ - migration_service.SearchMigratableResourcesResponse, - typing.Awaitable[migration_service.SearchMigratableResourcesResponse] - ]]: + def search_migratable_resources( + self, + ) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse], + ], + ]: raise NotImplementedError() @property - def batch_migrate_resources(self) -> typing.Callable[ - [migration_service.BatchMigrateResourcesRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def batch_migrate_resources( + self, + ) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'MigrationServiceTransport', -) +__all__ = ("MigrationServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index bf0e91b721..50d81c4ab3 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: 
ignore
from google.auth.transport.grpc import SslCredentials  # type: ignore

import grpc  # type: ignore

@@ -47,20 +47,23 @@ class MigrationServiceGrpcTransport(MigrationServiceTransport):
     It sends protocol buffers over the wire using gRPC (which is built on
     top of HTTP/2); the ``grpcio`` package must be installed.
     """
+
     _stubs: Dict[str, Callable]

-    def __init__(self, *,
-            host: str = 'aiplatform.googleapis.com',
-            credentials: credentials.Credentials = None,
-            credentials_file: str = None,
-            scopes: Sequence[str] = None,
-            channel: grpc.Channel = None,
-            api_mtls_endpoint: str = None,
-            client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
-            ssl_channel_credentials: grpc.ChannelCredentials = None,
-            quota_project_id: Optional[str] = None,
-            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-            ) -> None:
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
         """Instantiate the transport.

         Args:
@@ -110,12 +113,21 @@ def __init__(self, *,
             # If a channel was explicitly provided, set it.
             self._grpc_channel = channel
         elif api_mtls_endpoint:
-            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )

-            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )

             if credentials is None:
-                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )

             # Create SSL credentials with client_cert_source or application
             # default SSL credentials.
@@ -140,7 +152,9 @@ def __init__(self, *,
             host = host if ":" in host else host + ":443"

             if credentials is None:
-                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )

             # create a new channel. The provided one is ignored.
             self._grpc_channel = type(self).create_channel(
@@ -165,13 +179,15 @@ def __init__(self, *,
         )

     @classmethod
-    def create_channel(cls,
-                       host: str = 'aiplatform.googleapis.com',
-                       credentials: credentials.Credentials = None,
-                       credentials_file: str = None,
-                       scopes: Optional[Sequence[str]] = None,
-                       quota_project_id: Optional[str] = None,
-                       **kwargs) -> grpc.Channel:
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
         """Create and return a gRPC channel object.

         Args:
            address (Optional[str]): The host for the channel to use.
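The hunks above only re-wrap this logic for the formatter; to make the transport's two defaults concrete, here is a minimal standalone sketch (the helper names are hypothetical, not part of this patch): the host gains a ":443" suffix when no port is given, and missing credentials fall back to Application Default Credentials.

# Sketch, not patch content: the defaulting the transport applies to
# `host` and `credentials` before building a channel.
from google import auth

def normalize_host(host: str = "aiplatform.googleapis.com") -> str:
    # Default to port 443 (HTTPS) only when the caller gave no port.
    return host if ":" in host else host + ":443"

def resolve_credentials(credentials=None):
    if credentials is None:
        # Same fallback the transport uses: Application Default
        # Credentials with the cloud-platform scope.
        credentials, _ = auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"]
        )
    return credentials

assert normalize_host("example.com") == "example.com:443"
assert normalize_host("example.com:8443") == "example.com:8443"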
@@ -204,7 +220,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -221,18 +237,21 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def search_migratable_resources(self) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - migration_service.SearchMigratableResourcesResponse]: + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse, + ]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -250,18 +269,20 @@ def search_migratable_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_migratable_resources' not in self._stubs: - self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs['search_migratable_resources'] + return self._stubs["search_migratable_resources"] @property - def batch_migrate_resources(self) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - operations.Operation]: + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], operations.Operation + ]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -278,15 +299,13 @@ def batch_migrate_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
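Both cached properties above follow the same create-once idiom: the operations client is memoized on self.__dict__ and each gRPC stub is memoized in self._stubs under its method name. A generic sketch of that pattern, with hypothetical names:

# Sketch of the cache-on-first-access idiom used for stubs above.
from typing import Any, Callable, Dict

class StubCache:
    """Create each expensive object once, then reuse it."""

    def __init__(self) -> None:
        self._stubs: Dict[str, Any] = {}

    def get(self, name: str, factory: Callable[[], Any]) -> Any:
        if name not in self._stubs:
            self._stubs[name] = factory()  # only the first access pays the cost
        return self._stubs[name]

cache = StubCache()
stub = cache.get("search_migratable_resources", lambda: object())
assert stub is cache.get("search_migratable_resources", lambda: object())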
- if 'batch_migrate_resources' not in self._stubs: - self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_migrate_resources'] + return self._stubs["batch_migrate_resources"] -__all__ = ( - 'MigrationServiceGrpcTransport', -) +__all__ = ("MigrationServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 3c12daf987..1450fbf2b5 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import migration_service @@ -54,13 +54,15 @@ class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. 
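With the sync transport complete, the client and the SearchMigratableResourcesPager from earlier in this patch combine roughly as follows (default credentials assumed; project and location are placeholders):

# Hypothetical usage of the migration client defined in this patch.
from google.cloud.aiplatform_v1beta1.services.migration_service import (
    MigrationServiceClient,
)

client = MigrationServiceClient()
pager = client.search_migratable_resources(
    parent="projects/my-project/locations/us-central1"
)
# The pager fetches follow-up pages transparently via next_page_token.
for resource in pager:
    print(resource.last_migrate_time)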
@@ -89,21 +91,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -154,12 +158,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -184,7 +197,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -226,18 +241,21 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def search_migratable_resources(self) -> Callable[ - [migration_service.SearchMigratableResourcesRequest], - Awaitable[migration_service.SearchMigratableResourcesResponse]]: + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse], + ]: r"""Return a callable for the search migratable resources method over gRPC. Searches all of the resources in @@ -255,18 +273,21 @@ def search_migratable_resources(self) -> Callable[ # the request. 
# gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'search_migratable_resources' not in self._stubs: - self._stubs['search_migratable_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources', + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, ) - return self._stubs['search_migratable_resources'] + return self._stubs["search_migratable_resources"] @property - def batch_migrate_resources(self) -> Callable[ - [migration_service.BatchMigrateResourcesRequest], - Awaitable[operations.Operation]]: + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the batch migrate resources method over gRPC. Batch migrates resources from ml.googleapis.com, @@ -283,15 +304,13 @@ def batch_migrate_resources(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'batch_migrate_resources' not in self._stubs: - self._stubs['batch_migrate_resources'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources', + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['batch_migrate_resources'] + return self._stubs["batch_migrate_resources"] -__all__ = ( - 'MigrationServiceGrpcAsyncIOTransport', -) +__all__ = ("MigrationServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py index 3ee8fc6e9e..b39295ebfe 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import ModelServiceAsyncClient __all__ = ( - 'ModelServiceClient', - 'ModelServiceAsyncClient', + "ModelServiceClient", + "ModelServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index 1f35b4a15f..aa56f7d953 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from 
google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -62,26 +62,44 @@ class ModelServiceAsyncClient: model_path = staticmethod(ModelServiceClient.model_path) parse_model_path = staticmethod(ModelServiceClient.parse_model_path) model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) - parse_model_evaluation_path = staticmethod(ModelServiceClient.parse_model_evaluation_path) - model_evaluation_slice_path = staticmethod(ModelServiceClient.model_evaluation_slice_path) - parse_model_evaluation_slice_path = staticmethod(ModelServiceClient.parse_model_evaluation_slice_path) + parse_model_evaluation_path = staticmethod( + ModelServiceClient.parse_model_evaluation_path + ) + model_evaluation_slice_path = staticmethod( + ModelServiceClient.model_evaluation_slice_path + ) + parse_model_evaluation_slice_path = staticmethod( + ModelServiceClient.parse_model_evaluation_slice_path + ) training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod(ModelServiceClient.parse_training_pipeline_path) + parse_training_pipeline_path = staticmethod( + ModelServiceClient.parse_training_pipeline_path + ) - common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(ModelServiceClient.common_folder_path) parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + ModelServiceClient.parse_common_project_path + ) common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + ModelServiceClient.parse_common_location_path + ) from_service_account_file = ModelServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -95,14 +113,18 @@ def transport(self) -> ModelServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) + get_transport_class = functools.partial( + type(ModelServiceClient).get_transport_class, type(ModelServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - 
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -141,18 +163,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Uploads a Model artifact into AI Platform. Args: @@ -193,8 +215,10 @@ async def upload_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent, model]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.UploadModelRequest(request) @@ -217,18 +241,11 @@ async def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -241,14 +258,15 @@ async def upload_model(self, # Done; return the response. return response - async def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + async def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -276,8 +294,10 @@ async def get_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelRequest(request) @@ -298,30 +318,24 @@ async def get_model(self, # Certain fields should be provided within the metadata header; # add these here. 
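A hypothetical end-to-end call against the async surface defined above: upload_model returns an operation_async.AsyncOperation, and awaiting its result() coroutine yields the final UploadModelResponse. The parent and model payload are illustrative only:

# Sketch: uploading a Model with the async client and waiting on the LRO.
import asyncio

from google.cloud.aiplatform_v1beta1.services.model_service import (
    ModelServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.types import Model

async def upload() -> None:
    client = ModelServiceAsyncClient()
    op = await client.upload_model(
        parent="projects/my-project/locations/us-central1",
        model=Model(display_name="my-model"),
    )
    response = await op.result()  # wait for the long-running operation
    print(response.model)

asyncio.run(upload())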
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: + async def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: r"""Lists Models in a Location. Args: @@ -355,8 +369,10 @@ async def list_models(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelsRequest(request) @@ -377,40 +393,31 @@ async def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + async def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -446,8 +453,10 @@ async def update_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([model, update_mask]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) request = model_service.UpdateModelRequest(request) @@ -470,30 +479,26 @@ async def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -541,8 +546,10 @@ async def delete_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.DeleteModelRequest(request) @@ -563,18 +570,11 @@ async def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -587,15 +587,16 @@ async def delete_model(self, # Done; return the response. return response - async def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -640,8 +641,10 @@ async def export_model(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
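The ValueError above enforces one calling convention across every method in these clients: pass either a populated request object or the flattened fields, never both. A sketch with the sync client and a hypothetical resource name:

from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient
from google.cloud.aiplatform_v1beta1.types import GetModelRequest

client = ModelServiceClient()
name = "projects/my-project/locations/us-central1/models/123"

model = client.get_model(name=name)                           # flattened field
model = client.get_model(request=GetModelRequest(name=name))  # request object
# client.get_model(request=GetModelRequest(name=name), name=name)  # -> ValueError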
if request is not None and any([name, output_config]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ExportModelRequest(request) @@ -664,18 +667,11 @@ async def export_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -688,14 +684,15 @@ async def export_model(self, # Done; return the response. return response - async def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + async def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -729,8 +726,10 @@ async def get_model_evaluation(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelEvaluationRequest(request) @@ -751,30 +750,24 @@ async def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsAsyncPager: + async def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: r"""Lists ModelEvaluations in a Model. Args: @@ -808,8 +801,10 @@ async def list_model_evaluations(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
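For reference, the routing-header helper used by all of these methods turns the resource name into an x-goog-request-params metadata entry that the backend uses for routing. A small sketch of its observable behavior (the exact percent-encoding shown is approximate):

from google.api_core import gapic_v1

metadata = gapic_v1.routing_header.to_grpc_metadata(
    (("name", "projects/my-project/locations/us-central1/models/123"),)
)
# Roughly: [("x-goog-request-params",
#            "name=projects%2Fmy-project%2Flocations%2Fus-central1%2Fmodels%2F123")]
print(metadata)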
if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelEvaluationsRequest(request) @@ -830,39 +825,30 @@ async def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - async def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + async def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -896,8 +882,10 @@ async def get_model_evaluation_slice(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelEvaluationSliceRequest(request) @@ -918,30 +906,24 @@ async def get_model_evaluation_slice(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. 
return response - async def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesAsyncPager: + async def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -976,8 +958,10 @@ async def list_model_evaluation_slices(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelEvaluationSlicesRequest(request) @@ -998,47 +982,30 @@ async def list_model_evaluation_slices(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListModelEvaluationSlicesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. 
return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceAsyncClient', -) +__all__ = ("ModelServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index cade034da4..30c00c0c9d 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -60,13 +60,12 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry['grpc'] = ModelServiceGrpcTransport - _transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[ModelServiceTransport]: + def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. Args: @@ -117,7 +116,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -136,9 +135,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. 
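A hypothetical use of the service-account constructors wired up here; the key-file path is a placeholder:

from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

client = ModelServiceClient.from_service_account_file("service-account.json")
# from_service_account_json is an alias bound to the same constructor:
client = ModelServiceClient.from_service_account_json("service-account.json")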
""" - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -153,121 +151,162 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_path(project: str,location: str,model: str,) -> str: + def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parse a model path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_path(project: str,location: str,model: str,evaluation: str,) -> str: + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: """Return a fully-qualified model_evaluation string.""" - return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) @staticmethod - def parse_model_evaluation_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_path(path: str) -> Dict[str, str]: """Parse a model_evaluation path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def model_evaluation_slice_path(project: str,location: str,model: str,evaluation: str,slice: str,) -> str: + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: """Return a fully-qualified model_evaluation_slice string.""" - return 
"projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) @staticmethod - def parse_model_evaluation_slice_path(path: str) -> Dict[str,str]: + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: """Parse a model_evaluation_slice path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str: + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: """Return a fully-qualified training_pipeline string.""" - return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str,str]: + def parse_training_pipeline_path(path: str) -> Dict[str, str]: """Parse a training_pipeline path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def 
parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the model service client. Args: @@ -311,7 +350,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -339,7 +380,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -351,8 +394,10 @@ def __init__(self, *, if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." 
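The path helpers reformatted above are pure string utilities, so they can be exercised without any network access; a round trip looks like this (IDs are placeholders):

from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

path = ModelServiceClient.model_path("my-project", "us-central1", "123")
assert path == "projects/my-project/locations/us-central1/models/123"

parts = ModelServiceClient.parse_model_path(path)
assert parts == {"project": "my-project", "location": "us-central1", "model": "123"}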
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -371,15 +416,16 @@ def __init__(self, *, client_info=client_info, ) - def upload_model(self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Uploads a Model artifact into AI Platform. Args: @@ -421,8 +467,10 @@ def upload_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, model]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UploadModelRequest. @@ -446,18 +494,11 @@ def upload_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -470,14 +511,15 @@ def upload_model(self, # Done; return the response. return response - def get_model(self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets a Model. Args: @@ -506,8 +548,10 @@ def get_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -529,30 +573,24 @@ def get_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. 
return response - def list_models(self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists Models in a Location. Args: @@ -587,8 +625,10 @@ def list_models(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -610,40 +650,31 @@ def list_models(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def update_model(self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: + def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: r"""Updates a Model. Args: @@ -680,8 +711,10 @@ def update_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateModelRequest. @@ -705,30 +738,26 @@ def update_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('model.name', request.model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. 
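# ------------------------------------------------------------------
# [Editor's aside] Each regenerated method accepts per-call retry= and
# timeout= overrides from google.api_core; a hedged sketch (the Retry
# parameters are illustrative, and `client`/`parent` are assumed from the
# earlier example, not defined in this patch):
#
#     from google.api_core import retry as retries
#
#     custom_retry = retries.Retry(
#         initial=0.1, maximum=10.0, multiplier=1.3, deadline=60.0
#     )
#     for m in client.list_models(parent=parent, retry=custom_retry, timeout=30.0):
#         print(m.display_name)
# ------------------------------------------------------------------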
- response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def delete_model(self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a Model. Note: Model can only be deleted if there are no DeployedModels created from it. @@ -777,8 +806,10 @@ def delete_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteModelRequest. @@ -800,18 +831,11 @@ def delete_model(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -824,15 +848,16 @@ def delete_model(self, # Done; return the response. return response - def export_model(self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Exports a trained, exportable, Model to a location specified by the user. A Model is considered to be exportable if it has at least one [supported export @@ -878,8 +903,10 @@ def export_model(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name, output_config]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ExportModelRequest. @@ -903,18 +930,11 @@ def export_model(self, # Certain fields should be provided within the metadata header; # add these here. 
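# ------------------------------------------------------------------
# [Editor's aside] to_grpc_metadata() below converts the field/value pairs
# into the "x-goog-request-params" header the backend uses for request
# routing; roughly (the resource name is a placeholder and the value is
# URL-encoded):
#
#     from google.api_core import gapic_v1
#
#     md = gapic_v1.routing_header.to_grpc_metadata(
#         (("name", "projects/p/locations/l/models/m"),)
#     )
#     # md == ("x-goog-request-params", "name=<url-encoded resource name>")
# ------------------------------------------------------------------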
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -927,14 +947,15 @@ def export_model(self, # Done; return the response. return response - def get_model_evaluation(self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: + def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: r"""Gets a ModelEvaluation. Args: @@ -969,8 +990,10 @@ def get_model_evaluation(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationRequest. @@ -992,30 +1015,24 @@ def get_model_evaluation(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluations(self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsPager: + def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: r"""Lists ModelEvaluations in a Model. Args: @@ -1050,8 +1067,10 @@ def list_model_evaluations(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationsRequest. @@ -1073,39 +1092,30 @@ def list_model_evaluations(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def get_model_evaluation_slice(self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: + def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: r"""Gets a ModelEvaluationSlice. Args: @@ -1140,8 +1150,10 @@ def get_model_evaluation_slice(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelEvaluationSliceRequest. @@ -1158,35 +1170,31 @@ def get_model_evaluation_slice(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation_slice] + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_model_evaluation_slices(self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesPager: + def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: @@ -1222,8 +1230,10 @@ def list_model_evaluation_slices(self, # gotten any keyword arguments that map to the request. 
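# ------------------------------------------------------------------
# [Editor's aside] The check below enforces that callers pass either a
# request object or the flattened fields, never both; a sketch of the two
# equivalent call styles (the parent value is a placeholder, `client` is
# assumed from the earlier example):
#
#     from google.cloud.aiplatform_v1beta1.types import model_service
#
#     eval_name = "projects/p/locations/l/models/m/evaluations/e"
#     pager = client.list_model_evaluation_slices(
#         request=model_service.ListModelEvaluationSlicesRequest(parent=eval_name)
#     )
#     pager = client.list_model_evaluation_slices(parent=eval_name)  # same call
# ------------------------------------------------------------------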
has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelEvaluationSlicesRequest. @@ -1240,52 +1250,37 @@ def list_model_evaluation_slices(self, # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_evaluation_slices] + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'ModelServiceClient', -) +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 716d790932..1ab3aacb91 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -40,12 +40,15 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -79,7 +82,7 @@ def __iter__(self) -> Iterable[model.Model]: yield from page.models def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelsAsyncPager: @@ -99,12 +102,15 @@ class ListModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -142,7 +148,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationsPager: @@ -162,12 +168,15 @@ class ListModelEvaluationsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelEvaluationsResponse], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -201,7 +210,7 @@ def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: yield from page.model_evaluations def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationsAsyncPager: @@ -221,12 +230,15 @@ class ListModelEvaluationsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -264,7 +276,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesPager: @@ -284,12 +296,15 @@ class ListModelEvaluationSlicesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. 
""" - def __init__(self, - method: Callable[..., model_service.ListModelEvaluationSlicesResponse], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -323,7 +338,7 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: yield from page.model_evaluation_slices def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelEvaluationSlicesAsyncPager: @@ -343,12 +358,17 @@ class ListModelEvaluationSlicesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationSlicesResponse]], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] + ], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -370,7 +390,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + async def pages( + self, + ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -386,4 +408,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index 89bd6faee0..a521df9229 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. 
_transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry['grpc'] = ModelServiceGrpcTransport -_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport __all__ = ( - 'ModelServiceTransport', - 'ModelServiceGrpcTransport', - 'ModelServiceGrpcAsyncIOTransport', + "ModelServiceTransport", + "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index d5f10a9943..681d035178 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -37,29 +37,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -82,24 +82,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. 
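# [Editor's aside] Resolution order in the block below: explicit credentials
# win; otherwise a credentials_file is loaded through google.auth; otherwise
# Application Default Credentials apply. Supplying both credentials and a
# credentials_file raises DuplicateCredentialArgs.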
if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -111,34 +113,22 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, - default_timeout=None, - client_info=client_info, + self.upload_model, default_timeout=None, client_info=client_info, ), self.get_model: gapic_v1.method.wrap_method( - self.get_model, - default_timeout=None, - client_info=client_info, + self.get_model, default_timeout=None, client_info=client_info, ), self.list_models: gapic_v1.method.wrap_method( - self.list_models, - default_timeout=None, - client_info=client_info, + self.list_models, default_timeout=None, client_info=client_info, ), self.update_model: gapic_v1.method.wrap_method( - self.update_model, - default_timeout=None, - client_info=client_info, + self.update_model, default_timeout=None, client_info=client_info, ), self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, - default_timeout=None, - client_info=client_info, + self.delete_model, default_timeout=None, client_info=client_info, ), self.export_model: gapic_v1.method.wrap_method( - self.export_model, - default_timeout=None, - client_info=client_info, + self.export_model, default_timeout=None, client_info=client_info, ), self.get_model_evaluation: gapic_v1.method.wrap_method( self.get_model_evaluation, @@ -160,7 +150,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -169,96 +158,109 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def upload_model(self) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def upload_model( + self, + ) -> typing.Callable[ + [model_service.UploadModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model(self) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[ - model.Model, - typing.Awaitable[model.Model] - ]]: + def get_model( + self, + ) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[model.Model, typing.Awaitable[model.Model]], + ]: raise NotImplementedError() @property - def list_models(self) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse] - ]]: + def list_models( + self, + ) -> typing.Callable[ + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse], + ], + ]: raise NotImplementedError() @property - def update_model(self) -> typing.Callable[ - 
[model_service.UpdateModelRequest], - typing.Union[ - gca_model.Model, - typing.Awaitable[gca_model.Model] - ]]: + def update_model( + self, + ) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], + ]: raise NotImplementedError() @property - def delete_model(self) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_model( + self, + ) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def export_model(self) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def export_model( + self, + ) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_model_evaluation(self) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation] - ]]: + def get_model_evaluation( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation], + ], + ]: raise NotImplementedError() @property - def list_model_evaluations(self) -> typing.Callable[ - [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse] - ]]: + def list_model_evaluations( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse], + ], + ]: raise NotImplementedError() @property - def get_model_evaluation_slice(self) -> typing.Callable[ - [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice] - ]]: + def get_model_evaluation_slice( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ], + ]: raise NotImplementedError() @property - def list_model_evaluation_slices(self) -> typing.Callable[ - [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse] - ]]: + def list_model_evaluation_slices( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'ModelServiceTransport', -) +__all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 255d478e9d..df720617a7 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ 
b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -49,20 +49,23 @@ class ModelServiceGrpcTransport(ModelServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -112,12 +115,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -142,7 +154,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. 
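# [Editor's aside] On this fallback path no explicit channel was supplied:
# the host has just been normalized to "host:443" when no port was given,
# credentials fall back to Application Default Credentials, and the
# create_channel() call below builds the secure channel, applying
# ssl_channel_credentials when one was provided.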
self._grpc_channel = type(self).create_channel( @@ -167,13 +181,15 @@ ) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optional[str]): The host for the channel to use. @@ -206,7 +222,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -223,18 +239,18 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def upload_model(self) -> Callable[ - [model_service.UploadModelRequest], - operations.Operation]: + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], operations.Operation]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -249,18 +265,16 @@ def upload_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'upload_model' not in self._stubs: - self._stubs['upload_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['upload_model'] + return self._stubs["upload_model"] @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -275,18 +289,18 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
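# [Editor's aside] The pattern used for every stub accessor here lazily
# creates one unary-unary stub and caches it in self._stubs under the
# method name, so repeated property access reuses the same gRPC callable
# instead of re-registering it on the channel.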
- if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -301,18 +315,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def update_model(self) -> Callable[ - [model_service.UpdateModelRequest], - gca_model.Model]: + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -327,18 +341,18 @@ def update_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_model' not in self._stubs: - self._stubs['update_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs['update_model'] + return self._stubs["update_model"] @property - def delete_model(self) -> Callable[ - [model_service.DeleteModelRequest], - operations.Operation]: + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -355,18 +369,18 @@ def delete_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'delete_model' not in self._stubs: - self._stubs['delete_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model'] + return self._stubs["delete_model"] @property - def export_model(self) -> Callable[ - [model_service.ExportModelRequest], - operations.Operation]: + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -384,18 +398,20 @@ def export_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_model' not in self._stubs: - self._stubs['export_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_model'] + return self._stubs["export_model"] @property - def get_model_evaluation(self) -> Callable[ - [model_service.GetModelEvaluationRequest], - model_evaluation.ModelEvaluation]: + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation + ]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -410,18 +426,21 @@ def get_model_evaluation(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation' not in self._stubs: - self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs['get_model_evaluation'] + return self._stubs["get_model_evaluation"] @property - def list_model_evaluations(self) -> Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse]: + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse, + ]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -436,18 +455,21 @@ def list_model_evaluations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_model_evaluations' not in self._stubs: - self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs['list_model_evaluations'] + return self._stubs["list_model_evaluations"] @property - def get_model_evaluation_slice(self) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice]: + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice, + ]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -462,18 +484,21 @@ def get_model_evaluation_slice(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model_evaluation_slice' not in self._stubs: - self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs['get_model_evaluation_slice'] + return self._stubs["get_model_evaluation_slice"] @property - def list_model_evaluation_slices(self) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse]: + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse, + ]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -488,15 +513,13 @@ def list_model_evaluation_slices(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_model_evaluation_slices' not in self._stubs: - self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs['list_model_evaluation_slices'] + return self._stubs["list_model_evaluation_slices"] -__all__ = ( - 'ModelServiceGrpcTransport', -) +__all__ = ("ModelServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index 850a476d8f..ffe89774ef 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import model @@ -56,13 +56,15 @@ class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. 
@@ -91,21 +93,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -156,12 +160,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -186,7 +199,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -228,18 +243,18 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def upload_model(self) -> Callable[ - [model_service.UploadModelRequest], - Awaitable[operations.Operation]]: + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the upload model method over gRPC. Uploads a Model artifact into AI Platform. @@ -254,18 +269,18 @@ def upload_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'upload_model' not in self._stubs: - self._stubs['upload_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UploadModel', + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", request_serializer=model_service.UploadModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['upload_model'] + return self._stubs["upload_model"] @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Awaitable[model.Model]]: + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: r"""Return a callable for the get model method over gRPC. Gets a Model. @@ -280,18 +295,20 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Awaitable[model_service.ListModelsResponse]]: + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: r"""Return a callable for the list models method over gRPC. Lists Models in a Location. @@ -306,18 +323,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def update_model(self) -> Callable[ - [model_service.UpdateModelRequest], - Awaitable[gca_model.Model]]: + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: r"""Return a callable for the update model method over gRPC. Updates a Model. @@ -332,18 +349,18 @@ def update_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'update_model' not in self._stubs: - self._stubs['update_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel', + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", request_serializer=model_service.UpdateModelRequest.serialize, response_deserializer=gca_model.Model.deserialize, ) - return self._stubs['update_model'] + return self._stubs["update_model"] @property - def delete_model(self) -> Callable[ - [model_service.DeleteModelRequest], - Awaitable[operations.Operation]]: + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the delete model method over gRPC. Deletes a Model. @@ -360,18 +377,18 @@ def delete_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_model' not in self._stubs: - self._stubs['delete_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel', + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", request_serializer=model_service.DeleteModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_model'] + return self._stubs["delete_model"] @property - def export_model(self) -> Callable[ - [model_service.ExportModelRequest], - Awaitable[operations.Operation]]: + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: r"""Return a callable for the export model method over gRPC. Exports a trained, exportable, Model to a location specified by @@ -389,18 +406,21 @@ def export_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'export_model' not in self._stubs: - self._stubs['export_model'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ExportModel', + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", request_serializer=model_service.ExportModelRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['export_model'] + return self._stubs["export_model"] @property - def get_model_evaluation(self) -> Callable[ - [model_service.GetModelEvaluationRequest], - Awaitable[model_evaluation.ModelEvaluation]]: + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation], + ]: r"""Return a callable for the get model evaluation method over gRPC. Gets a ModelEvaluation. @@ -415,18 +435,21 @@ def get_model_evaluation(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'get_model_evaluation' not in self._stubs: - self._stubs['get_model_evaluation'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation', + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", request_serializer=model_service.GetModelEvaluationRequest.serialize, response_deserializer=model_evaluation.ModelEvaluation.deserialize, ) - return self._stubs['get_model_evaluation'] + return self._stubs["get_model_evaluation"] @property - def list_model_evaluations(self) -> Callable[ - [model_service.ListModelEvaluationsRequest], - Awaitable[model_service.ListModelEvaluationsResponse]]: + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse], + ]: r"""Return a callable for the list model evaluations method over gRPC. Lists ModelEvaluations in a Model. @@ -441,18 +464,21 @@ def list_model_evaluations(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluations' not in self._stubs: - self._stubs['list_model_evaluations'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations', + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", request_serializer=model_service.ListModelEvaluationsRequest.serialize, response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, ) - return self._stubs['list_model_evaluations'] + return self._stubs["list_model_evaluations"] @property - def get_model_evaluation_slice(self) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - Awaitable[model_evaluation_slice.ModelEvaluationSlice]]: + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ]: r"""Return a callable for the get model evaluation slice method over gRPC. Gets a ModelEvaluationSlice. @@ -467,18 +493,21 @@ def get_model_evaluation_slice(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'get_model_evaluation_slice' not in self._stubs: - self._stubs['get_model_evaluation_slice'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice', + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, ) - return self._stubs['get_model_evaluation_slice'] + return self._stubs["get_model_evaluation_slice"] @property - def list_model_evaluation_slices(self) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - Awaitable[model_service.ListModelEvaluationSlicesResponse]]: + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse], + ]: r"""Return a callable for the list model evaluation slices method over gRPC. Lists ModelEvaluationSlices in a ModelEvaluation. @@ -493,15 +522,13 @@ def list_model_evaluation_slices(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_model_evaluation_slices' not in self._stubs: - self._stubs['list_model_evaluation_slices'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices', + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, ) - return self._stubs['list_model_evaluation_slices'] + return self._stubs["list_model_evaluation_slices"] -__all__ = ( - 'ModelServiceGrpcAsyncIOTransport', -) +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py index f7f4d9b9ac..7f02b47358 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PipelineServiceAsyncClient __all__ = ( - 'PipelineServiceClient', - 'PipelineServiceAsyncClient', + "PipelineServiceClient", + "PipelineServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 6035fc4277..22777c2405 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # 
type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -36,7 +36,9 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -60,22 +62,38 @@ class PipelineServiceAsyncClient: model_path = staticmethod(PipelineServiceClient.model_path) parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod(PipelineServiceClient.parse_training_pipeline_path) + parse_training_pipeline_path = staticmethod( + PipelineServiceClient.parse_training_pipeline_path + ) - common_billing_account_path = staticmethod(PipelineServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PipelineServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + PipelineServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PipelineServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PipelineServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + PipelineServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(PipelineServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PipelineServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + PipelineServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PipelineServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PipelineServiceClient.common_project_path) - parse_common_project_path = staticmethod(PipelineServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PipelineServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PipelineServiceClient.common_location_path) - parse_common_location_path = staticmethod(PipelineServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PipelineServiceClient.parse_common_location_path + ) from_service_account_file = PipelineServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -89,14 +107,18 @@ def transport(self) -> PipelineServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient)) + get_transport_class = functools.partial( + 
type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -135,18 +157,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + async def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -188,8 +210,10 @@ async def create_training_pipeline(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent, training_pipeline]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.CreateTrainingPipelineRequest(request) @@ -212,30 +236,24 @@ async def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + async def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. 
Args: @@ -271,8 +289,10 @@ async def get_training_pipeline(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.GetTrainingPipelineRequest(request) @@ -293,30 +313,24 @@ async def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesAsyncPager: + async def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: r"""Lists TrainingPipelines in a Location. Args: @@ -350,8 +364,10 @@ async def list_training_pipelines(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.ListTrainingPipelinesRequest(request) @@ -372,39 +388,30 @@ async def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListTrainingPipelinesAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. 
return response - async def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a TrainingPipeline. Args: @@ -451,8 +458,10 @@ async def delete_training_pipeline(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.DeleteTrainingPipelineRequest(request) @@ -473,18 +482,11 @@ async def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -497,14 +499,15 @@ async def delete_training_pipeline(self, # Done; return the response. return response - async def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -542,8 +545,10 @@ async def cancel_training_pipeline(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = pipeline_service.CancelTrainingPipelineRequest(request) @@ -564,35 +569,23 @@ async def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
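        # The RPC returns google.protobuf.Empty, so the method itself returns
        # None; callers can poll GetTrainingPipeline to observe the effect of
        # the (best-effort) cancellation.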
await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceAsyncClient', -) +__all__ = ("PipelineServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index fbecd0dc70..e3e7d6aeda 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -40,7 +40,9 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -58,13 +60,14 @@ class PipelineServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] - _transport_registry['grpc'] = PipelineServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PipelineServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PipelineServiceTransport]] + _transport_registry["grpc"] = PipelineServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. 
Args:
@@ -115,7 +118,7 @@ def _get_default_mtls_endpoint(api_endpoint):
         return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

-    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
     DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
         DEFAULT_ENDPOINT
     )
@@ -134,9 +137,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(
-            filename)
-        kwargs['credentials'] = credentials
+        credentials = service_account.Credentials.from_service_account_file(filename)
+        kwargs["credentials"] = credentials
         return cls(*args, **kwargs)

     from_service_account_json = from_service_account_file
@@ -151,99 +153,122 @@ def transport(self) -> PipelineServiceTransport:
         return self._transport

     @staticmethod
-    def endpoint_path(project: str,location: str,endpoint: str,) -> str:
+    def endpoint_path(project: str, location: str, endpoint: str,) -> str:
         """Return a fully-qualified endpoint string."""
-        return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, )
+        return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+            project=project, location=location, endpoint=endpoint,
+        )

     @staticmethod
-    def parse_endpoint_path(path: str) -> Dict[str,str]:
+    def parse_endpoint_path(path: str) -> Dict[str, str]:
         """Parse an endpoint path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$", path)
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
+            path,
+        )
         return m.groupdict() if m else {}

     @staticmethod
-    def model_path(project: str,location: str,model: str,) -> str:
+    def model_path(project: str, location: str, model: str,) -> str:
         """Return a fully-qualified model string."""
-        return "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )

     @staticmethod
-    def parse_model_path(path: str) -> Dict[str,str]:
+    def parse_model_path(path: str) -> Dict[str, str]:
         """Parse a model path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$", path)
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
         return m.groupdict() if m else {}

     @staticmethod
-    def training_pipeline_path(project: str,location: str,training_pipeline: str,) -> str:
+    def training_pipeline_path(
+        project: str, location: str, training_pipeline: str,
+    ) -> str:
         """Return a fully-qualified training_pipeline string."""
-        return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, )
+        return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
+            project=project, location=location, training_pipeline=training_pipeline,
+        )

     @staticmethod
-    def parse_training_pipeline_path(path: str) -> Dict[str,str]:
+    def parse_training_pipeline_path(path: str) -> Dict[str, str]:
         """Parse a training_pipeline path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/trainingPipelines/(?P<training_pipeline>.+?)$", path)
+        m = re.match(
r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + 
transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the pipeline service client. Args: @@ -287,7 +312,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -315,7 +342,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -327,8 +356,10 @@ def __init__(self, *, if isinstance(transport, PipelineServiceTransport): # transport is a PipelineServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -347,15 +378,16 @@ def __init__(self, *, client_info=client_info, ) - def create_training_pipeline(self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: + def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: r"""Creates a TrainingPipeline. A created TrainingPipeline right away will be attempted to be run. @@ -398,8 +430,10 @@ def create_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, training_pipeline]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CreateTrainingPipelineRequest. @@ -423,30 +457,24 @@ def create_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. 
- response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def get_training_pipeline(self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: + def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: r"""Gets a TrainingPipeline. Args: @@ -483,8 +511,10 @@ def get_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.GetTrainingPipelineRequest. @@ -506,30 +536,24 @@ def get_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_training_pipelines(self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesPager: + def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: r"""Lists TrainingPipelines in a Location. Args: @@ -564,8 +588,10 @@ def list_training_pipelines(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.ListTrainingPipelinesRequest. @@ -587,39 +613,30 @@ def list_training_pipelines(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. 
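        # Only the first page is fetched here; the pager constructed below
        # retrieves subsequent pages lazily as the caller iterates, e.g.
        # (illustrative usage, assuming a constructed client):
        #
        #     for pipeline in client.list_training_pipelines(parent=parent):
        #         print(pipeline.display_name)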
- response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_training_pipeline(self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a TrainingPipeline. Args: @@ -667,8 +684,10 @@ def delete_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.DeleteTrainingPipelineRequest. @@ -690,18 +709,11 @@ def delete_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -714,14 +726,15 @@ def delete_training_pipeline(self, # Done; return the response. return response - def cancel_training_pipeline(self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on the TrainingPipeline. The server makes a best effort to cancel the pipeline, but success is not guaranteed. Clients can use @@ -760,8 +773,10 @@ def cancel_training_pipeline(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) # Minor optimization to avoid making a copy if the user passes # in a pipeline_service.CancelTrainingPipelineRequest. @@ -783,35 +798,23 @@ def cancel_training_pipeline(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, + request, retry=retry, timeout=timeout, metadata=metadata, ) - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PipelineServiceClient', -) +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index beee148035..98e5a51a17 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -38,12 +38,15 @@ class ListTrainingPipelinesPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +80,7 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: yield from page.training_pipelines def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTrainingPipelinesAsyncPager: @@ -97,12 +100,17 @@ class ListTrainingPipelinesAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[pipeline_service.ListTrainingPipelinesResponse]], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. 
Args: @@ -124,7 +132,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + async def pages( + self, + ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +150,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index 3caa4c7906..d9d71a892b 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] -_transport_registry['grpc'] = PipelineServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PipelineServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = PipelineServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport __all__ = ( - 'PipelineServiceTransport', - 'PipelineServiceGrpcTransport', - 'PipelineServiceGrpcAsyncIOTransport', + "PipelineServiceTransport", + "PipelineServiceGrpcTransport", + "PipelineServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 0a74b8e8b6..1b235635f1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -21,14 +21,16 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -36,29 +38,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = 
None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -81,24 +83,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -134,7 +138,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -143,51 +146,58 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_training_pipeline(self) -> typing.Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - typing.Union[ - gca_training_pipeline.TrainingPipeline, - typing.Awaitable[gca_training_pipeline.TrainingPipeline] - ]]: + def create_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline], + ], + ]: raise NotImplementedError() @property - def get_training_pipeline(self) -> typing.Callable[ - [pipeline_service.GetTrainingPipelineRequest], - typing.Union[ - training_pipeline.TrainingPipeline, - typing.Awaitable[training_pipeline.TrainingPipeline] - ]]: + def get_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline], + ], + ]: raise NotImplementedError() @property - def list_training_pipelines(self) -> typing.Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - typing.Union[ - pipeline_service.ListTrainingPipelinesResponse, - typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse] - ]]: + def list_training_pipelines( + self, + ) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ], + ]: raise NotImplementedError() @property - def delete_training_pipeline(self) -> 
typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def cancel_training_pipeline(self) -> typing.Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - typing.Union[ - empty.Empty, - typing.Awaitable[empty.Empty] - ]]: + def cancel_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: raise NotImplementedError() -__all__ = ( - 'PipelineServiceTransport', -) +__all__ = ("PipelineServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 096505204b..66580ae42e 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -18,18 +18,20 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -48,20 +50,23 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. 
""" + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -111,12 +116,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -141,7 +155,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -166,13 +182,15 @@ def __init__(self, *, ) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optionsl[str]): The host for the channel to use. @@ -205,7 +223,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -222,18 +240,21 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. 
- if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_training_pipeline(self) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline]: + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + gca_training_pipeline.TrainingPipeline, + ]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -249,18 +270,21 @@ def create_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_training_pipeline' not in self._stubs: - self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['create_training_pipeline'] + return self._stubs["create_training_pipeline"] @property - def get_training_pipeline(self) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline]: + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline, + ]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -275,18 +299,21 @@ def get_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_training_pipeline' not in self._stubs: - self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['get_training_pipeline'] + return self._stubs["get_training_pipeline"] @property - def list_training_pipelines(self) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse]: + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse, + ]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -301,18 +328,20 @@ def list_training_pipelines(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_training_pipelines' not in self._stubs: - self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs['list_training_pipelines'] + return self._stubs["list_training_pipelines"] @property - def delete_training_pipeline(self) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - operations.Operation]: + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation + ]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -327,18 +356,18 @@ def delete_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_training_pipeline' not in self._stubs: - self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_training_pipeline'] + return self._stubs["delete_training_pipeline"] @property - def cancel_training_pipeline(self) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - empty.Empty]: + def cancel_training_pipeline( + self, + ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -365,15 +394,13 @@ def cancel_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'cancel_training_pipeline' not in self._stubs: - self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_training_pipeline'] + return self._stubs["cancel_training_pipeline"] -__all__ = ( - 'PipelineServiceGrpcTransport', -) +__all__ = ("PipelineServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index ce9bc0c191..a66285f6dc 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -18,19 +18,21 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore @@ -55,13 +57,15 @@ class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. 
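As an illustrative aside, this AsyncIO transport is what ``PipelineServiceAsyncClient`` selects by default, but it can also be built and injected explicitly, for example to reuse a pre-built channel. A minimal sketch, assuming application default credentials are available; the host simply mirrors the default shown above.

from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
    PipelineServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.services.pipeline_service.transports import (
    PipelineServiceGrpcAsyncIOTransport,
)

# Build the channel with the classmethod defined above, then hand the
# transport instance to the client in place of a transport name string.
channel = PipelineServiceGrpcAsyncIOTransport.create_channel(
    host="aiplatform.googleapis.com",
)
transport = PipelineServiceGrpcAsyncIOTransport(channel=channel)
client = PipelineServiceAsyncClient(transport=transport)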
@@ -90,21 +94,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -155,12 +161,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -185,7 +200,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -227,18 +244,21 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_training_pipeline(self) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - Awaitable[gca_training_pipeline.TrainingPipeline]]: + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline], + ]: r"""Return a callable for the create training pipeline method over gRPC. Creates a TrainingPipeline. A created @@ -254,18 +274,21 @@ def create_training_pipeline(self) -> Callable[ # the request. 
# gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_training_pipeline' not in self._stubs: - self._stubs['create_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline', + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['create_training_pipeline'] + return self._stubs["create_training_pipeline"] @property - def get_training_pipeline(self) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - Awaitable[training_pipeline.TrainingPipeline]]: + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline], + ]: r"""Return a callable for the get training pipeline method over gRPC. Gets a TrainingPipeline. @@ -280,18 +303,21 @@ def get_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_training_pipeline' not in self._stubs: - self._stubs['get_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline', + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, response_deserializer=training_pipeline.TrainingPipeline.deserialize, ) - return self._stubs['get_training_pipeline'] + return self._stubs["get_training_pipeline"] @property - def list_training_pipelines(self) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - Awaitable[pipeline_service.ListTrainingPipelinesResponse]]: + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ]: r"""Return a callable for the list training pipelines method over gRPC. Lists TrainingPipelines in a Location. @@ -306,18 +332,21 @@ def list_training_pipelines(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_training_pipelines' not in self._stubs: - self._stubs['list_training_pipelines'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines', + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, ) - return self._stubs['list_training_pipelines'] + return self._stubs["list_training_pipelines"] @property - def delete_training_pipeline(self) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - Awaitable[operations.Operation]]: + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete training pipeline method over gRPC. Deletes a TrainingPipeline. @@ -332,18 +361,20 @@ def delete_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_training_pipeline' not in self._stubs: - self._stubs['delete_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline', + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_training_pipeline'] + return self._stubs["delete_training_pipeline"] @property - def cancel_training_pipeline(self) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - Awaitable[empty.Empty]]: + def cancel_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] + ]: r"""Return a callable for the cancel training pipeline method over gRPC. Cancels a TrainingPipeline. Starts asynchronous cancellation on @@ -370,15 +401,13 @@ def cancel_training_pipeline(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'cancel_training_pipeline' not in self._stubs: - self._stubs['cancel_training_pipeline'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline', + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, response_deserializer=empty.Empty.FromString, ) - return self._stubs['cancel_training_pipeline'] + return self._stubs["cancel_training_pipeline"] -__all__ = ( - 'PipelineServiceGrpcAsyncIOTransport', -) +__all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py index d4047c335d..0c847693e0 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import PredictionServiceAsyncClient __all__ = ( - 'PredictionServiceClient', - 'PredictionServiceAsyncClient', + "PredictionServiceClient", + "PredictionServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index 283bb73f3e..606ce0f46b 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -48,20 +48,34 @@ class PredictionServiceAsyncClient: endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) - common_billing_account_path = staticmethod(PredictionServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PredictionServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + PredictionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PredictionServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PredictionServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + PredictionServiceClient.parse_common_folder_path + ) - common_organization_path = 
staticmethod(PredictionServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PredictionServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + PredictionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PredictionServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PredictionServiceClient.common_project_path) - parse_common_project_path = staticmethod(PredictionServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PredictionServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PredictionServiceClient.common_location_path) - parse_common_location_path = staticmethod(PredictionServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PredictionServiceClient.parse_common_location_path + ) from_service_account_file = PredictionServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -75,14 +89,18 @@ def transport(self) -> PredictionServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient)) + get_transport_class = functools.partial( + type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PredictionServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -121,19 +139,19 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def predict(self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + async def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. Args: @@ -189,8 +207,10 @@ async def predict(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([endpoint, instances, parameters]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = prediction_service.PredictRequest(request) @@ -215,33 +235,27 @@ async def predict(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def explain(self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + async def explain( + self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. If @@ -314,9 +328,13 @@ async def explain(self, # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, instances, parameters, deployed_model_id]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + if request is not None and any( + [endpoint, instances, parameters, deployed_model_id] + ): + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = prediction_service.ExplainRequest(request) @@ -343,38 +361,24 @@ async def explain(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. 
return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PredictionServiceAsyncClient', -) +__all__ = ("PredictionServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 2627b20ae3..9a5976d697 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -48,13 +48,16 @@ class PredictionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] - _transport_registry['grpc'] = PredictionServiceGrpcTransport - _transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[PredictionServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry["grpc"] = PredictionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[PredictionServiceTransport]: """Return an appropriate transport class. Args: @@ -105,7 +108,7 @@ def _get_default_mtls_endpoint(api_endpoint): return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - DEFAULT_ENDPOINT = 'aiplatform.googleapis.com' + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore DEFAULT_ENDPOINT ) @@ -124,9 +127,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: {@api.name}: The constructed client. 
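As an illustrative aside, the classmethod documented above makes key-file construction a one-liner (``from_service_account_json``, defined just below, is an alias of it). A minimal sketch; the key-file path is hypothetical.

from google.cloud.aiplatform_v1beta1.services.prediction_service import (
    PredictionServiceClient,
)

# Loads credentials from the key file and forwards them to the constructor.
client = PredictionServiceClient.from_service_account_file(
    "/path/to/service-account-key.json",  # hypothetical key file
)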
""" - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs['credentials'] = credentials + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials return cls(*args, **kwargs) from_service_account_json = from_service_account_file @@ -141,77 +143,88 @@ def transport(self) -> PredictionServiceTransport: return self._transport @staticmethod - def endpoint_path(project: str,location: str,endpoint: str,) -> str: + def endpoint_path(project: str, location: str, endpoint: str,) -> str: """Return a fully-qualified endpoint string.""" - return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) @staticmethod - def parse_endpoint_path(path: str) -> Dict[str,str]: + def parse_endpoint_path(path: str) -> Dict[str, str]: """Parse a endpoint path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path(folder: str,) -> str: """Return a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format(folder=folder,) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path(organization: str,) -> str: """Return a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format(organization=organization,) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path(project: str,) -> str: """Return a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format(project=project,) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into 
its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path(project: str, location: str,) -> str: """Return a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} - def __init__(self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the prediction service client. Args: @@ -255,7 +268,9 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))) + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) ssl_credentials = None is_mtls = False @@ -283,7 +298,9 @@ def __init__(self, *, elif use_mtls_env == "always": api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) else: raise MutualTLSChannelError( "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" @@ -295,8 +312,10 @@ def __init__(self, *, if isinstance(transport, PredictionServiceTransport): # transport is a PredictionServiceTransport instance. if credentials or client_options.credentials_file: - raise ValueError('When providing a transport instance, ' - 'provide its credentials directly.') + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -315,16 +334,17 @@ def __init__(self, *, client_info=client_info, ) - def predict(self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: + def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: r"""Perform an online prediction. 
Args: @@ -381,8 +401,10 @@ def predict(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a prediction_service.PredictRequest. @@ -409,33 +431,27 @@ def predict(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def explain(self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: + def explain( + self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. If @@ -510,8 +526,10 @@ def explain(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a prediction_service.ExplainRequest. @@ -540,38 +558,24 @@ def explain(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('endpoint', request.endpoint), - )), + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. 
return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'PredictionServiceClient', -) +__all__ = ("PredictionServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py index e130201fdf..7eb32ea86d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py @@ -25,12 +25,12 @@ # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] -_transport_registry['grpc'] = PredictionServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PredictionServiceGrpcAsyncIOTransport +_transport_registry["grpc"] = PredictionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport __all__ = ( - 'PredictionServiceTransport', - 'PredictionServiceGrpcTransport', - 'PredictionServiceGrpcAsyncIOTransport', + "PredictionServiceTransport", + "PredictionServiceGrpcTransport", + "PredictionServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 86e2292130..cdec1c11e5 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore @@ -31,29 +31,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class PredictionServiceTransport(abc.ABC): """Abstract transport class for PredictionService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -76,24 +76,26 @@ def __init__( your own client library. """ # Save the hostname. 
Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -105,37 +107,36 @@ def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.predict: gapic_v1.method.wrap_method( - self.predict, - default_timeout=None, - client_info=client_info, + self.predict, default_timeout=None, client_info=client_info, ), self.explain: gapic_v1.method.wrap_method( - self.explain, - default_timeout=None, - client_info=client_info, + self.explain, default_timeout=None, client_info=client_info, ), - } @property - def predict(self) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse] - ]]: + def predict( + self, + ) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse], + ], + ]: raise NotImplementedError() @property - def explain(self) -> typing.Callable[ - [prediction_service.ExplainRequest], - typing.Union[ - prediction_service.ExplainResponse, - typing.Awaitable[prediction_service.ExplainResponse] - ]]: + def explain( + self, + ) -> typing.Callable[ + [prediction_service.ExplainRequest], + typing.Union[ + prediction_service.ExplainResponse, + typing.Awaitable[prediction_service.ExplainResponse], + ], + ]: raise NotImplementedError() -__all__ = ( - 'PredictionServiceTransport', -) +__all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 520120cfa3..6c4cdf8d12 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -18,10 +18,10 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -43,20 +43,23 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of 
HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -106,12 +109,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -136,7 +148,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -161,13 +175,15 @@ def __init__(self, *, ) @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: address (Optional[str]): The host for the channel to use.
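As an illustrative aside, ``create_channel`` above is a thin wrapper around ``grpc_helpers.create_channel`` from ``google.api_core``; with no explicit credentials its effect is roughly the sketch below, which assumes application default credentials are available.

from google.api_core import grpc_helpers

# Roughly what the classmethod does by default: resolve application
# default credentials for the cloud-platform scope and open a secure
# channel to the service endpoint.
channel = grpc_helpers.create_channel(
    "aiplatform.googleapis.com:443",
    scopes=("https://www.googleapis.com/auth/cloud-platform",),
)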
@@ -200,7 +216,7 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property @@ -210,9 +226,11 @@ def grpc_channel(self) -> grpc.Channel: return self._grpc_channel @property - def predict(self) -> Callable[ - [prediction_service.PredictRequest], - prediction_service.PredictResponse]: + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], prediction_service.PredictResponse + ]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -227,18 +245,20 @@ def predict(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'predict' not in self._stubs: - self._stubs['predict'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs['predict'] + return self._stubs["predict"] @property - def explain(self) -> Callable[ - [prediction_service.ExplainRequest], - prediction_service.ExplainResponse]: + def explain( + self, + ) -> Callable[ + [prediction_service.ExplainRequest], prediction_service.ExplainResponse + ]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. @@ -264,15 +284,13 @@ def explain(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'explain' not in self._stubs: - self._stubs['explain'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', + if "explain" not in self._stubs: + self._stubs["explain"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs['explain'] + return self._stubs["explain"] -__all__ = ( - 'PredictionServiceGrpcTransport', -) +__all__ = ("PredictionServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 1a1d48b450..f8d06bc047 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -18,13 +18,13 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service @@ 
-50,13 +50,15 @@ class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. @@ -85,21 +87,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -150,12 +154,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -180,7 +193,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. 
self._grpc_channel = type(self).create_channel( @@ -215,9 +230,12 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def predict(self) -> Callable[ - [prediction_service.PredictRequest], - Awaitable[prediction_service.PredictResponse]]: + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse], + ]: r"""Return a callable for the predict method over gRPC. Perform an online prediction. @@ -232,18 +250,21 @@ def predict(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'predict' not in self._stubs: - self._stubs['predict'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Predict', + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", request_serializer=prediction_service.PredictRequest.serialize, response_deserializer=prediction_service.PredictResponse.deserialize, ) - return self._stubs['predict'] + return self._stubs["predict"] @property - def explain(self) -> Callable[ - [prediction_service.ExplainRequest], - Awaitable[prediction_service.ExplainResponse]]: + def explain( + self, + ) -> Callable[ + [prediction_service.ExplainRequest], + Awaitable[prediction_service.ExplainResponse], + ]: r"""Return a callable for the explain method over gRPC. Perform an online explanation. @@ -269,15 +290,13 @@ def explain(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'explain' not in self._stubs: - self._stubs['explain'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.PredictionService/Explain', + if "explain" not in self._stubs: + self._stubs["explain"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", request_serializer=prediction_service.ExplainRequest.serialize, response_deserializer=prediction_service.ExplainResponse.deserialize, ) - return self._stubs['explain'] + return self._stubs["explain"] -__all__ = ( - 'PredictionServiceGrpcAsyncIOTransport', -) +__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py index e4247d7758..49e9cdf0a0 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py @@ -19,6 +19,6 @@ from .async_client import SpecialistPoolServiceAsyncClient __all__ = ( - 'SpecialistPoolServiceClient', - 'SpecialistPoolServiceAsyncClient', + "SpecialistPoolServiceClient", + "SpecialistPoolServiceAsyncClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index c4ea8855c1..507ce92262 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -21,12 +21,12 @@ from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core 
import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -57,23 +57,43 @@ class SpecialistPoolServiceAsyncClient: DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT - specialist_pool_path = staticmethod(SpecialistPoolServiceClient.specialist_pool_path) - parse_specialist_pool_path = staticmethod(SpecialistPoolServiceClient.parse_specialist_pool_path) + specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.specialist_pool_path + ) + parse_specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.parse_specialist_pool_path + ) - common_billing_account_path = staticmethod(SpecialistPoolServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(SpecialistPoolServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(SpecialistPoolServiceClient.parse_common_folder_path) + parse_common_folder_path = staticmethod( + SpecialistPoolServiceClient.parse_common_folder_path + ) - common_organization_path = staticmethod(SpecialistPoolServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(SpecialistPoolServiceClient.parse_common_organization_path) + common_organization_path = staticmethod( + SpecialistPoolServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + SpecialistPoolServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) - parse_common_project_path = staticmethod(SpecialistPoolServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + SpecialistPoolServiceClient.parse_common_project_path + ) - common_location_path = staticmethod(SpecialistPoolServiceClient.common_location_path) - parse_common_location_path = staticmethod(SpecialistPoolServiceClient.parse_common_location_path) + common_location_path = staticmethod( + SpecialistPoolServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + SpecialistPoolServiceClient.parse_common_location_path + ) from_service_account_file = SpecialistPoolServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -87,14 +107,19 @@ def transport(self) -> SpecialistPoolServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(SpecialistPoolServiceClient).get_transport_class, type(SpecialistPoolServiceClient)) + get_transport_class = functools.partial( + 
type(SpecialistPoolServiceClient).get_transport_class, + type(SpecialistPoolServiceClient), + ) - def __init__(self, *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = 'grpc_asyncio', - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the specialist pool service client. Args: @@ -133,18 +158,18 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a SpecialistPool. Args: @@ -191,8 +216,10 @@ async def create_specialist_pool(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent, specialist_pool]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.CreateSpecialistPoolRequest(request) @@ -215,18 +242,11 @@ async def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -239,14 +259,15 @@ async def create_specialist_pool(self, # Done; return the response. return response - async def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + async def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. 
Args: @@ -287,8 +308,10 @@ async def get_specialist_pool(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.GetSpecialistPoolRequest(request) @@ -309,30 +332,24 @@ async def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - async def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsAsyncPager: + async def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: r"""Lists SpecialistPools in a Location. Args: @@ -366,8 +383,10 @@ async def list_specialist_pools(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([parent]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.ListSpecialistPoolsRequest(request) @@ -388,39 +407,30 @@ async def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__aiter__` convenience method. response = pagers.ListSpecialistPoolsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. 
return response - async def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -467,8 +477,10 @@ async def delete_specialist_pool(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([name]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.DeleteSpecialistPoolRequest(request) @@ -489,18 +501,11 @@ async def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -513,15 +518,16 @@ async def delete_specialist_pool(self, # Done; return the response. return response - async def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Updates a SpecialistPool. Args: @@ -567,8 +573,10 @@ async def update_specialist_pool(self, # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. if request is not None and any([specialist_pool, update_mask]): - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = specialist_pool_service.UpdateSpecialistPoolRequest(request) @@ -591,18 +599,13 @@ async def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = operation_async.from_gapic( @@ -616,21 +619,14 @@ async def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceAsyncClient', -) +__all__ = ("SpecialistPoolServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index f6938e8d1f..efc19eca12 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -23,14 +23,14 @@ import pkg_resources from google.api_core import client_options as client_options_lib # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore @@ -54,13 +54,16 @@ class SpecialistPoolServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] - _transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport - _transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport - def get_transport_class(cls, - label: str = None, - ) -> Type[SpecialistPoolServiceTransport]: + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: """Return an appropriate transport class. 
Args:
@@ -117,7 +120,7 @@ def _get_default_mtls_endpoint(api_endpoint):
         return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

-    DEFAULT_ENDPOINT = 'aiplatform.googleapis.com'
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
     DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
         DEFAULT_ENDPOINT
     )
@@ -136,9 +139,8 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
         Returns:
             {@api.name}: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(
-            filename)
-        kwargs['credentials'] = credentials
+        credentials = service_account.Credentials.from_service_account_file(filename)
+        kwargs["credentials"] = credentials
         return cls(*args, **kwargs)

     from_service_account_json = from_service_account_file
@@ -153,77 +155,88 @@ def transport(self) -> SpecialistPoolServiceTransport:
         return self._transport

     @staticmethod
-    def specialist_pool_path(project: str,location: str,specialist_pool: str,) -> str:
+    def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str:
         """Return a fully-qualified specialist_pool string."""
-        return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, )
+        return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(
+            project=project, location=location, specialist_pool=specialist_pool,
+        )

     @staticmethod
-    def parse_specialist_pool_path(path: str) -> Dict[str,str]:
+    def parse_specialist_pool_path(path: str) -> Dict[str, str]:
         """Parse a specialist_pool path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/specialistPools/(?P<specialist_pool>.+?)$", path)
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/specialistPools/(?P<specialist_pool>.+?)$",
+            path,
+        )
         return m.groupdict() if m else {}

     @staticmethod
-    def common_billing_account_path(billing_account: str, ) -> str:
+    def common_billing_account_path(billing_account: str,) -> str:
         """Return a fully-qualified billing_account string."""
-        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )

     @staticmethod
-    def parse_common_billing_account_path(path: str) -> Dict[str,str]:
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
         """Parse a billing_account path into its component segments."""
         m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_folder_path(folder: str, ) -> str:
+    def common_folder_path(folder: str,) -> str:
         """Return a fully-qualified folder string."""
-        return "folders/{folder}".format(folder=folder, )
+        return "folders/{folder}".format(folder=folder,)

     @staticmethod
-    def parse_common_folder_path(path: str) -> Dict[str,str]:
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
         """Parse a folder path into its component segments."""
         m = re.match(r"^folders/(?P<folder>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_organization_path(organization: str, ) -> str:
+    def common_organization_path(organization: str,) -> str:
         """Return a fully-qualified organization string."""
-        return "organizations/{organization}".format(organization=organization, )
+        return "organizations/{organization}".format(organization=organization,)

     @staticmethod
-    def parse_common_organization_path(path: str) -> Dict[str,str]:
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
         """Parse an organization path into its component segments."""
         m = re.match(r"^organizations/(?P<organization>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_project_path(project: str, ) -> str:
+    def common_project_path(project: str,) -> str:
         """Return a fully-qualified project string."""
-        return "projects/{project}".format(project=project, )
+        return "projects/{project}".format(project=project,)

     @staticmethod
-    def parse_common_project_path(path: str) -> Dict[str,str]:
+    def parse_common_project_path(path: str) -> Dict[str, str]:
         """Parse a project path into its component segments."""
         m = re.match(r"^projects/(?P<project>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_location_path(project: str, location: str, ) -> str:
+    def common_location_path(project: str, location: str,) -> str:
         """Return a fully-qualified location string."""
-        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )

     @staticmethod
-    def parse_common_location_path(path: str) -> Dict[str,str]:
+    def parse_common_location_path(path: str) -> Dict[str, str]:
         """Parse a location path into its component segments."""
         m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
         return m.groupdict() if m else {}

-    def __init__(self, *,
-            credentials: Optional[credentials.Credentials] = None,
-            transport: Union[str, SpecialistPoolServiceTransport, None] = None,
-            client_options: Optional[client_options_lib.ClientOptions] = None,
-            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-            ) -> None:
+    def __init__(
+        self,
+        *,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, SpecialistPoolServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
         """Instantiate the specialist pool service client.

         Args:
@@ -267,7 +280,9 @@ def __init__(self, *,
             client_options = client_options_lib.ClientOptions()

         # Create SSL credentials for mutual TLS if needed.
-        use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))
+        use_client_cert = bool(
+            util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+        )

         ssl_credentials = None
         is_mtls = False
@@ -295,7 +310,9 @@ def __init__(self, *,
             elif use_mtls_env == "always":
                 api_endpoint = self.DEFAULT_MTLS_ENDPOINT
             elif use_mtls_env == "auto":
-                api_endpoint = self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
+                api_endpoint = (
+                    self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
+                )
             else:
                 raise MutualTLSChannelError(
                     "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always"
                 )
@@ -307,8 +324,10 @@ def __init__(self, *,
         if isinstance(transport, SpecialistPoolServiceTransport):
             # transport is a SpecialistPoolServiceTransport instance.
             if credentials or client_options.credentials_file:
-                raise ValueError('When providing a transport instance, '
-                                 'provide its credentials directly.')
+                raise ValueError(
+                    "When providing a transport instance, "
+                    "provide its credentials directly."
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, " @@ -327,15 +346,16 @@ def __init__(self, *, client_info=client_info, ) - def create_specialist_pool(self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Creates a SpecialistPool. Args: @@ -383,8 +403,10 @@ def create_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, specialist_pool]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.CreateSpecialistPoolRequest. @@ -408,18 +430,11 @@ def create_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -432,14 +447,15 @@ def create_specialist_pool(self, # Done; return the response. return response - def get_specialist_pool(self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: + def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: r"""Gets a SpecialistPool. Args: @@ -481,8 +497,10 @@ def get_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.GetSpecialistPoolRequest. @@ -504,30 +522,24 @@ def get_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. 
metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Done; return the response. return response - def list_specialist_pools(self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsPager: + def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: r"""Lists SpecialistPools in a Location. Args: @@ -562,8 +574,10 @@ def list_specialist_pools(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.ListSpecialistPoolsRequest. @@ -585,39 +599,30 @@ def list_specialist_pools(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. return response - def delete_specialist_pool(self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Deletes a SpecialistPool as well as all Specialists in the pool. @@ -665,8 +670,10 @@ def delete_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.DeleteSpecialistPoolRequest. @@ -688,18 +695,11 @@ def delete_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('name', request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. response = ga_operation.from_gapic( @@ -712,15 +712,16 @@ def delete_specialist_pool(self, # Done; return the response. return response - def update_specialist_pool(self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> ga_operation.Operation: + def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: r"""Updates a SpecialistPool. Args: @@ -767,8 +768,10 @@ def update_specialist_pool(self, # gotten any keyword arguments that map to the request. has_flattened_params = any([specialist_pool, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a specialist_pool_service.UpdateSpecialistPoolRequest. @@ -792,18 +795,13 @@ def update_specialist_pool(self, # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('specialist_pool.name', request.specialist_pool.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) # Wrap the response in an operation future. 
response = ga_operation.from_gapic( @@ -817,21 +815,14 @@ def update_specialist_pool(self, return response - - - - - try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() -__all__ = ( - 'SpecialistPoolServiceClient', -) +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 68093dbff5..ff2d84ac74 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -38,12 +38,15 @@ class ListSpecialistPoolsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -77,7 +80,7 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: yield from page.specialist_pools def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListSpecialistPoolsAsyncPager: @@ -97,12 +100,17 @@ class ListSpecialistPoolsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[ + ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. 
Args: @@ -124,7 +132,9 @@ def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @property - async def pages(self) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + async def pages( + self, + ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token @@ -140,4 +150,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index ed5bf01517..711f7fd1cc 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -24,13 +24,15 @@ # Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[SpecialistPoolServiceTransport]] -_transport_registry['grpc'] = SpecialistPoolServiceGrpcTransport -_transport_registry['grpc_asyncio'] = SpecialistPoolServiceGrpcAsyncIOTransport +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[SpecialistPoolServiceTransport]] +_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport +_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( - 'SpecialistPoolServiceTransport', - 'SpecialistPoolServiceGrpcTransport', - 'SpecialistPoolServiceGrpcAsyncIOTransport', + "SpecialistPoolServiceTransport", + "SpecialistPoolServiceGrpcTransport", + "SpecialistPoolServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index 20c4d1cf3c..30fbd3030f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -21,7 +21,7 @@ from google import auth # type: ignore from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -34,29 +34,29 @@ try: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( - 'google-cloud-aiplatform', + "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" - AUTH_SCOPES = ( - 'https://www.googleapis.com/auth/cloud-platform', - ) + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) def __init__( - self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = 
DEFAULT_CLIENT_INFO, - **kwargs, - ) -> None: + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -79,24 +79,26 @@ def __init__( your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, - scopes=scopes, - quota_project_id=quota_project_id - ) + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = auth.default(scopes=scopes, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials @@ -113,9 +115,7 @@ def _prep_wrapped_messages(self, client_info): client_info=client_info, ), self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, - default_timeout=None, - client_info=client_info, + self.get_specialist_pool, default_timeout=None, client_info=client_info, ), self.list_specialist_pools: gapic_v1.method.wrap_method( self.list_specialist_pools, @@ -132,7 +132,6 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } @property @@ -141,51 +140,55 @@ def operations_client(self) -> operations_v1.OperationsClient: raise NotImplementedError() @property - def create_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def create_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def get_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool] - ]]: + def get_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool], + ], + ]: raise NotImplementedError() @property - def list_specialist_pools(self) -> typing.Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ]]: + def list_specialist_pools( + self, + ) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + 
typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ], + ]: raise NotImplementedError() @property - def delete_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def delete_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() @property - def update_specialist_pool(self) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[ - operations.Operation, - typing.Awaitable[operations.Operation] - ]]: + def update_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: raise NotImplementedError() -__all__ = ( - 'SpecialistPoolServiceTransport', -) +__all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 071a58862f..18bdaaa035 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -18,11 +18,11 @@ import warnings from typing import Callable, Dict, Optional, Sequence, Tuple -from google.api_core import grpc_helpers # type: ignore +from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -51,20 +51,23 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -114,12 +117,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. 
            self._grpc_channel = channel
        elif api_mtls_endpoint:
-            warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )

-            host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )

             if credentials is None:
-                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )

             # Create SSL credentials with client_cert_source or application
             # default SSL credentials.
@@ -144,7 +156,9 @@ def __init__(self, *,
             host = host if ":" in host else host + ":443"

             if credentials is None:
-                credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )

             # create a new channel. The provided one is ignored.
             self._grpc_channel = type(self).create_channel(
@@ -169,13 +183,15 @@ def __init__(self, *,
         )

     @classmethod
-    def create_channel(cls,
-                       host: str = 'aiplatform.googleapis.com',
-                       credentials: credentials.Credentials = None,
-                       credentials_file: str = None,
-                       scopes: Optional[Sequence[str]] = None,
-                       quota_project_id: Optional[str] = None,
-                       **kwargs) -> grpc.Channel:
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
             address (Optional[str]): The host for the channel to use.
@@ -208,7 +224,7 @@ def create_channel(cls,
             credentials_file=credentials_file,
             scopes=scopes,
             quota_project_id=quota_project_id,
-            **kwargs
+            **kwargs,
         )

     @property
@@ -225,18 +241,20 @@ def operations_client(self) -> operations_v1.OperationsClient:
             client.
         """
         # Sanity check: Only create a new client if we do not already have one.
-        if 'operations_client' not in self.__dict__:
-            self.__dict__['operations_client'] = operations_v1.OperationsClient(
+        if "operations_client" not in self.__dict__:
+            self.__dict__["operations_client"] = operations_v1.OperationsClient(
                 self.grpc_channel
             )

         # Return the client from cache.
-        return self.__dict__['operations_client']
+        return self.__dict__["operations_client"]

     @property
-    def create_specialist_pool(self) -> Callable[
-            [specialist_pool_service.CreateSpecialistPoolRequest],
-            operations.Operation]:
+    def create_specialist_pool(
+        self,
+    ) -> Callable[
+        [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation
+    ]:
         r"""Return a callable for the create specialist pool method over gRPC.

         Creates a SpecialistPool.
@@ -251,18 +269,21 @@ def create_specialist_pool(self) -> Callable[
         # the request.
         # gRPC handles serialization and deserialization, so we just need
         # to pass in the functions for each.
- if 'create_specialist_pool' not in self._stubs: - self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_specialist_pool'] + return self._stubs["create_specialist_pool"] @property - def get_specialist_pool(self) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool]: + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool, + ]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -277,18 +298,21 @@ def get_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_specialist_pool' not in self._stubs: - self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs['get_specialist_pool'] + return self._stubs["get_specialist_pool"] @property - def list_specialist_pools(self) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse]: + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse, + ]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -303,18 +327,20 @@ def list_specialist_pools(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_specialist_pools' not in self._stubs: - self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs['list_specialist_pools'] + return self._stubs["list_specialist_pools"] @property - def delete_specialist_pool(self) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - operations.Operation]: + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the delete specialist pool method over gRPC. 
Deletes a SpecialistPool as well as all Specialists @@ -330,18 +356,20 @@ def delete_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_specialist_pool' not in self._stubs: - self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_specialist_pool'] + return self._stubs["delete_specialist_pool"] @property - def update_specialist_pool(self) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - operations.Operation]: + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation + ]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -356,15 +384,13 @@ def update_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_specialist_pool' not in self._stubs: - self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_specialist_pool'] + return self._stubs["update_specialist_pool"] -__all__ = ( - 'SpecialistPoolServiceGrpcTransport', -) +__all__ = ("SpecialistPoolServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index 68639540e7..e2763c647f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -18,14 +18,14 @@ import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore -import grpc # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.cloud.aiplatform_v1beta1.types import specialist_pool @@ -58,13 +58,15 @@ class 
SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: address (Optional[str]): The host for the channel to use. @@ -93,21 +95,23 @@ def create_channel(cls, credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'aiplatform.googleapis.com', - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiate the transport. Args: @@ -158,12 +162,21 @@ def __init__(self, *, # If a channel was explicitly provided, set it. self._grpc_channel = channel elif api_mtls_endpoint: - warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) - host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # Create SSL credentials with client_cert_source or application # default SSL credentials. @@ -188,7 +201,9 @@ def __init__(self, *, host = host if ":" in host else host + ":443" if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id) + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( @@ -230,18 +245,21 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. 
- if 'operations_client' not in self.__dict__: - self.__dict__['operations_client'] = operations_v1.OperationsAsyncClient( + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__['operations_client'] + return self.__dict__["operations_client"] @property - def create_specialist_pool(self) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the create specialist pool method over gRPC. Creates a SpecialistPool. @@ -256,18 +274,21 @@ def create_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_specialist_pool' not in self._stubs: - self._stubs['create_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool', + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['create_specialist_pool'] + return self._stubs["create_specialist_pool"] @property - def get_specialist_pool(self) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - Awaitable[specialist_pool.SpecialistPool]]: + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool], + ]: r"""Return a callable for the get specialist pool method over gRPC. Gets a SpecialistPool. @@ -282,18 +303,21 @@ def get_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_specialist_pool' not in self._stubs: - self._stubs['get_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool', + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, response_deserializer=specialist_pool.SpecialistPool.deserialize, ) - return self._stubs['get_specialist_pool'] + return self._stubs["get_specialist_pool"] @property - def list_specialist_pools(self) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - Awaitable[specialist_pool_service.ListSpecialistPoolsResponse]]: + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ]: r"""Return a callable for the list specialist pools method over gRPC. Lists SpecialistPools in a Location. @@ -308,18 +332,21 @@ def list_specialist_pools(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_specialist_pools' not in self._stubs: - self._stubs['list_specialist_pools'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools', + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, ) - return self._stubs['list_specialist_pools'] + return self._stubs["list_specialist_pools"] @property - def delete_specialist_pool(self) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the delete specialist pool method over gRPC. Deletes a SpecialistPool as well as all Specialists @@ -335,18 +362,21 @@ def delete_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'delete_specialist_pool' not in self._stubs: - self._stubs['delete_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool', + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['delete_specialist_pool'] + return self._stubs["delete_specialist_pool"] @property - def update_specialist_pool(self) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - Awaitable[operations.Operation]]: + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: r"""Return a callable for the update specialist pool method over gRPC. Updates a SpecialistPool. @@ -361,15 +391,13 @@ def update_specialist_pool(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'update_specialist_pool' not in self._stubs: - self._stubs['update_specialist_pool'] = self.grpc_channel.unary_unary( - '/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool', + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, response_deserializer=operations.Operation.FromString, ) - return self._stubs['update_specialist_pool'] + return self._stubs["update_specialist_pool"] -__all__ = ( - 'SpecialistPoolServiceGrpcAsyncIOTransport', -) +__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 08cb2d804e..82fa939f8c 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,199 +15,361 @@ # limitations under the License. # -from .annotation_spec import (AnnotationSpec, ) -from .io import (GcsSource, GcsDestination, BigQuerySource, BigQueryDestination, ContainerRegistryDestination, ) -from .dataset import (Dataset, ImportDataConfig, ExportDataConfig, ) -from .manual_batch_tuning_parameters import (ManualBatchTuningParameters, ) -from .completion_stats import (CompletionStats, ) -from .model_evaluation_slice import (ModelEvaluationSlice, ) -from .machine_resources import (MachineSpec, DedicatedResources, AutomaticResources, BatchDedicatedResources, ResourcesConsumed, ) -from .deployed_model_ref import (DeployedModelRef, ) -from .env_var import (EnvVar, ) -from .explanation_metadata import (ExplanationMetadata, ) -from .explanation import (Explanation, ModelExplanation, Attribution, ExplanationSpec, ExplanationParameters, SampledShapleyAttribution, ) -from .model import (Model, PredictSchemata, ModelContainerSpec, Port, ) -from .training_pipeline import (TrainingPipeline, InputDataConfig, FractionSplit, FilterSplit, PredefinedSplit, TimestampSplit, ) -from .model_evaluation import (ModelEvaluation, ) -from .migratable_resource import (MigratableResource, ) -from .operation import (GenericOperationMetadata, DeleteOperationMetadata, ) -from .migration_service import (SearchMigratableResourcesRequest, SearchMigratableResourcesResponse, BatchMigrateResourcesRequest, MigrateResourceRequest, BatchMigrateResourcesResponse, MigrateResourceResponse, BatchMigrateResourcesOperationMetadata, ) -from .batch_prediction_job import (BatchPredictionJob, ) -from .custom_job import (CustomJob, CustomJobSpec, WorkerPoolSpec, ContainerSpec, PythonPackageSpec, Scheduling, ) -from .specialist_pool import (SpecialistPool, ) -from .data_labeling_job import (DataLabelingJob, ActiveLearningConfig, SampleConfig, TrainingConfig, ) -from .study import (Trial, StudySpec, Measurement, ) -from .hyperparameter_tuning_job import (HyperparameterTuningJob, ) -from .job_service import (CreateCustomJobRequest, GetCustomJobRequest, ListCustomJobsRequest, ListCustomJobsResponse, DeleteCustomJobRequest, CancelCustomJobRequest, CreateDataLabelingJobRequest, GetDataLabelingJobRequest, ListDataLabelingJobsRequest, ListDataLabelingJobsResponse, DeleteDataLabelingJobRequest, CancelDataLabelingJobRequest, CreateHyperparameterTuningJobRequest, GetHyperparameterTuningJobRequest, ListHyperparameterTuningJobsRequest, ListHyperparameterTuningJobsResponse, DeleteHyperparameterTuningJobRequest, 
CancelHyperparameterTuningJobRequest, CreateBatchPredictionJobRequest, GetBatchPredictionJobRequest, ListBatchPredictionJobsRequest, ListBatchPredictionJobsResponse, DeleteBatchPredictionJobRequest, CancelBatchPredictionJobRequest, ) -from .user_action_reference import (UserActionReference, ) -from .annotation import (Annotation, ) -from .endpoint import (Endpoint, DeployedModel, ) -from .prediction_service import (PredictRequest, PredictResponse, ExplainRequest, ExplainResponse, ) -from .endpoint_service import (CreateEndpointRequest, CreateEndpointOperationMetadata, GetEndpointRequest, ListEndpointsRequest, ListEndpointsResponse, UpdateEndpointRequest, DeleteEndpointRequest, DeployModelRequest, DeployModelResponse, DeployModelOperationMetadata, UndeployModelRequest, UndeployModelResponse, UndeployModelOperationMetadata, ) -from .pipeline_service import (CreateTrainingPipelineRequest, GetTrainingPipelineRequest, ListTrainingPipelinesRequest, ListTrainingPipelinesResponse, DeleteTrainingPipelineRequest, CancelTrainingPipelineRequest, ) -from .specialist_pool_service import (CreateSpecialistPoolRequest, CreateSpecialistPoolOperationMetadata, GetSpecialistPoolRequest, ListSpecialistPoolsRequest, ListSpecialistPoolsResponse, DeleteSpecialistPoolRequest, UpdateSpecialistPoolRequest, UpdateSpecialistPoolOperationMetadata, ) -from .model_service import (UploadModelRequest, UploadModelOperationMetadata, UploadModelResponse, GetModelRequest, ListModelsRequest, ListModelsResponse, UpdateModelRequest, DeleteModelRequest, ExportModelRequest, ExportModelOperationMetadata, ExportModelResponse, GetModelEvaluationRequest, ListModelEvaluationsRequest, ListModelEvaluationsResponse, GetModelEvaluationSliceRequest, ListModelEvaluationSlicesRequest, ListModelEvaluationSlicesResponse, ) -from .data_item import (DataItem, ) -from .dataset_service import (CreateDatasetRequest, CreateDatasetOperationMetadata, GetDatasetRequest, UpdateDatasetRequest, ListDatasetsRequest, ListDatasetsResponse, DeleteDatasetRequest, ImportDataRequest, ImportDataResponse, ImportDataOperationMetadata, ExportDataRequest, ExportDataResponse, ExportDataOperationMetadata, ListDataItemsRequest, ListDataItemsResponse, GetAnnotationSpecRequest, ListAnnotationsRequest, ListAnnotationsResponse, ) +from .annotation_spec import AnnotationSpec +from .io import ( + GcsSource, + GcsDestination, + BigQuerySource, + BigQueryDestination, + ContainerRegistryDestination, +) +from .dataset import ( + Dataset, + ImportDataConfig, + ExportDataConfig, +) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters +from .completion_stats import CompletionStats +from .model_evaluation_slice import ModelEvaluationSlice +from .machine_resources import ( + MachineSpec, + DedicatedResources, + AutomaticResources, + BatchDedicatedResources, + ResourcesConsumed, +) +from .deployed_model_ref import DeployedModelRef +from .env_var import EnvVar +from .explanation_metadata import ExplanationMetadata +from .explanation import ( + Explanation, + ModelExplanation, + Attribution, + ExplanationSpec, + ExplanationParameters, + SampledShapleyAttribution, +) +from .model import ( + Model, + PredictSchemata, + ModelContainerSpec, + Port, +) +from .training_pipeline import ( + TrainingPipeline, + InputDataConfig, + FractionSplit, + FilterSplit, + PredefinedSplit, + TimestampSplit, +) +from .model_evaluation import ModelEvaluation +from .migratable_resource import MigratableResource +from .operation import ( + GenericOperationMetadata, + DeleteOperationMetadata, +) 
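The import regrouping running through this file is purely cosmetic; every public name stays importable exactly as before. A minimal sketch of the two equivalent access paths, assuming only the re-exports declared in this __init__ (and mirrored in the __all__ tuple below):

    from google.cloud.aiplatform_v1beta1.types import (
        GenericOperationMetadata,
        DeleteOperationMetadata,
    )
    from google.cloud.aiplatform_v1beta1 import types

    # Both routes resolve to the same classes.
    assert types.GenericOperationMetadata is GenericOperationMetadata
    assert types.DeleteOperationMetadata is DeleteOperationMetadata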
+from .migration_service import ( + SearchMigratableResourcesRequest, + SearchMigratableResourcesResponse, + BatchMigrateResourcesRequest, + MigrateResourceRequest, + BatchMigrateResourcesResponse, + MigrateResourceResponse, + BatchMigrateResourcesOperationMetadata, +) +from .batch_prediction_job import BatchPredictionJob +from .custom_job import ( + CustomJob, + CustomJobSpec, + WorkerPoolSpec, + ContainerSpec, + PythonPackageSpec, + Scheduling, +) +from .specialist_pool import SpecialistPool +from .data_labeling_job import ( + DataLabelingJob, + ActiveLearningConfig, + SampleConfig, + TrainingConfig, +) +from .study import ( + Trial, + StudySpec, + Measurement, +) +from .hyperparameter_tuning_job import HyperparameterTuningJob +from .job_service import ( + CreateCustomJobRequest, + GetCustomJobRequest, + ListCustomJobsRequest, + ListCustomJobsResponse, + DeleteCustomJobRequest, + CancelCustomJobRequest, + CreateDataLabelingJobRequest, + GetDataLabelingJobRequest, + ListDataLabelingJobsRequest, + ListDataLabelingJobsResponse, + DeleteDataLabelingJobRequest, + CancelDataLabelingJobRequest, + CreateHyperparameterTuningJobRequest, + GetHyperparameterTuningJobRequest, + ListHyperparameterTuningJobsRequest, + ListHyperparameterTuningJobsResponse, + DeleteHyperparameterTuningJobRequest, + CancelHyperparameterTuningJobRequest, + CreateBatchPredictionJobRequest, + GetBatchPredictionJobRequest, + ListBatchPredictionJobsRequest, + ListBatchPredictionJobsResponse, + DeleteBatchPredictionJobRequest, + CancelBatchPredictionJobRequest, +) +from .user_action_reference import UserActionReference +from .annotation import Annotation +from .endpoint import ( + Endpoint, + DeployedModel, +) +from .prediction_service import ( + PredictRequest, + PredictResponse, + ExplainRequest, + ExplainResponse, +) +from .endpoint_service import ( + CreateEndpointRequest, + CreateEndpointOperationMetadata, + GetEndpointRequest, + ListEndpointsRequest, + ListEndpointsResponse, + UpdateEndpointRequest, + DeleteEndpointRequest, + DeployModelRequest, + DeployModelResponse, + DeployModelOperationMetadata, + UndeployModelRequest, + UndeployModelResponse, + UndeployModelOperationMetadata, +) +from .pipeline_service import ( + CreateTrainingPipelineRequest, + GetTrainingPipelineRequest, + ListTrainingPipelinesRequest, + ListTrainingPipelinesResponse, + DeleteTrainingPipelineRequest, + CancelTrainingPipelineRequest, +) +from .specialist_pool_service import ( + CreateSpecialistPoolRequest, + CreateSpecialistPoolOperationMetadata, + GetSpecialistPoolRequest, + ListSpecialistPoolsRequest, + ListSpecialistPoolsResponse, + DeleteSpecialistPoolRequest, + UpdateSpecialistPoolRequest, + UpdateSpecialistPoolOperationMetadata, +) +from .model_service import ( + UploadModelRequest, + UploadModelOperationMetadata, + UploadModelResponse, + GetModelRequest, + ListModelsRequest, + ListModelsResponse, + UpdateModelRequest, + DeleteModelRequest, + ExportModelRequest, + ExportModelOperationMetadata, + ExportModelResponse, + GetModelEvaluationRequest, + ListModelEvaluationsRequest, + ListModelEvaluationsResponse, + GetModelEvaluationSliceRequest, + ListModelEvaluationSlicesRequest, + ListModelEvaluationSlicesResponse, +) +from .data_item import DataItem +from .dataset_service import ( + CreateDatasetRequest, + CreateDatasetOperationMetadata, + GetDatasetRequest, + UpdateDatasetRequest, + ListDatasetsRequest, + ListDatasetsResponse, + DeleteDatasetRequest, + ImportDataRequest, + ImportDataResponse, + ImportDataOperationMetadata, + ExportDataRequest, + 
ExportDataResponse, + ExportDataOperationMetadata, + ListDataItemsRequest, + ListDataItemsResponse, + GetAnnotationSpecRequest, + ListAnnotationsRequest, + ListAnnotationsResponse, +) __all__ = ( - 'AnnotationSpec', - 'GcsSource', - 'GcsDestination', - 'BigQuerySource', - 'BigQueryDestination', - 'ContainerRegistryDestination', - 'Dataset', - 'ImportDataConfig', - 'ExportDataConfig', - 'ManualBatchTuningParameters', - 'CompletionStats', - 'ModelEvaluationSlice', - 'MachineSpec', - 'DedicatedResources', - 'AutomaticResources', - 'BatchDedicatedResources', - 'ResourcesConsumed', - 'DeployedModelRef', - 'EnvVar', - 'ExplanationMetadata', - 'Explanation', - 'ModelExplanation', - 'Attribution', - 'ExplanationSpec', - 'ExplanationParameters', - 'SampledShapleyAttribution', - 'Model', - 'PredictSchemata', - 'ModelContainerSpec', - 'Port', - 'TrainingPipeline', - 'InputDataConfig', - 'FractionSplit', - 'FilterSplit', - 'PredefinedSplit', - 'TimestampSplit', - 'ModelEvaluation', - 'MigratableResource', - 'GenericOperationMetadata', - 'DeleteOperationMetadata', - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'BatchMigrateResourcesRequest', - 'MigrateResourceRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceResponse', - 'BatchMigrateResourcesOperationMetadata', - 'BatchPredictionJob', - 'CustomJob', - 'CustomJobSpec', - 'WorkerPoolSpec', - 'ContainerSpec', - 'PythonPackageSpec', - 'Scheduling', - 'SpecialistPool', - 'DataLabelingJob', - 'ActiveLearningConfig', - 'SampleConfig', - 'TrainingConfig', - 'Trial', - 'StudySpec', - 'Measurement', - 'HyperparameterTuningJob', - 'CreateCustomJobRequest', - 'GetCustomJobRequest', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'DeleteCustomJobRequest', - 'CancelCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'GetDataLabelingJobRequest', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'DeleteDataLabelingJobRequest', - 'CancelDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'GetHyperparameterTuningJobRequest', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'DeleteHyperparameterTuningJobRequest', - 'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'GetBatchPredictionJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'DeleteBatchPredictionJobRequest', - 'CancelBatchPredictionJobRequest', - 'UserActionReference', - 'Annotation', - 'Endpoint', - 'DeployedModel', - 'PredictRequest', - 'PredictResponse', - 'ExplainRequest', - 'ExplainResponse', - 'CreateEndpointRequest', - 'CreateEndpointOperationMetadata', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UpdateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UndeployModelOperationMetadata', - 'CreateTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'DeleteTrainingPipelineRequest', - 'CancelTrainingPipelineRequest', - 'CreateSpecialistPoolRequest', - 'CreateSpecialistPoolOperationMetadata', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'DeleteSpecialistPoolRequest', - 'UpdateSpecialistPoolRequest', - 'UpdateSpecialistPoolOperationMetadata', - 'UploadModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelResponse', - 
'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'DeleteModelRequest', - 'ExportModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'GetModelEvaluationSliceRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', - 'DataItem', - 'CreateDatasetRequest', - 'CreateDatasetOperationMetadata', - 'GetDatasetRequest', - 'UpdateDatasetRequest', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'DeleteDatasetRequest', - 'ImportDataRequest', - 'ImportDataResponse', - 'ImportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportDataOperationMetadata', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'GetAnnotationSpecRequest', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', + "AnnotationSpec", + "GcsSource", + "GcsDestination", + "BigQuerySource", + "BigQueryDestination", + "ContainerRegistryDestination", + "Dataset", + "ImportDataConfig", + "ExportDataConfig", + "ManualBatchTuningParameters", + "CompletionStats", + "ModelEvaluationSlice", + "MachineSpec", + "DedicatedResources", + "AutomaticResources", + "BatchDedicatedResources", + "ResourcesConsumed", + "DeployedModelRef", + "EnvVar", + "ExplanationMetadata", + "Explanation", + "ModelExplanation", + "Attribution", + "ExplanationSpec", + "ExplanationParameters", + "SampledShapleyAttribution", + "Model", + "PredictSchemata", + "ModelContainerSpec", + "Port", + "TrainingPipeline", + "InputDataConfig", + "FractionSplit", + "FilterSplit", + "PredefinedSplit", + "TimestampSplit", + "ModelEvaluation", + "MigratableResource", + "GenericOperationMetadata", + "DeleteOperationMetadata", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", + "BatchPredictionJob", + "CustomJob", + "CustomJobSpec", + "WorkerPoolSpec", + "ContainerSpec", + "PythonPackageSpec", + "Scheduling", + "SpecialistPool", + "DataLabelingJob", + "ActiveLearningConfig", + "SampleConfig", + "TrainingConfig", + "Trial", + "StudySpec", + "Measurement", + "HyperparameterTuningJob", + "CreateCustomJobRequest", + "GetCustomJobRequest", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "DeleteCustomJobRequest", + "CancelCustomJobRequest", + "CreateDataLabelingJobRequest", + "GetDataLabelingJobRequest", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "DeleteDataLabelingJobRequest", + "CancelDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "GetHyperparameterTuningJobRequest", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "DeleteHyperparameterTuningJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "GetBatchPredictionJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "DeleteBatchPredictionJobRequest", + "CancelBatchPredictionJobRequest", + "UserActionReference", + "Annotation", + "Endpoint", + "DeployedModel", + "PredictRequest", + "PredictResponse", + "ExplainRequest", + "ExplainResponse", + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + 
"DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", + "UploadModelRequest", + "UploadModelOperationMetadata", + "UploadModelResponse", + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "DeleteModelRequest", + "ExportModelRequest", + "ExportModelOperationMetadata", + "ExportModelResponse", + "GetModelEvaluationRequest", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "GetModelEvaluationSliceRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "DataItem", + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", ) diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index e82a142396..337b0eeaf5 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'AcceleratorType', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"AcceleratorType",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index f3f36fb568..7734fcc512 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Annotation', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Annotation",}, ) @@ -94,22 +91,16 @@ class Annotation(proto.Message): payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field(proto.MESSAGE, number=5, - message=user_action_reference.UserActionReference, + annotation_source = proto.Field( + proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, ) labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git 
a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index 068aca741b..a5a4b3d489 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'AnnotationSpec', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"AnnotationSpec",}, ) @@ -58,13 +55,9 @@ class AnnotationSpec(proto.Message): display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 55c81889e7..2f464e6c8f 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -18,21 +18,22 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import completion_stats as gca_completion_stats +from google.cloud.aiplatform_v1beta1.types import ( + completion_stats as gca_completion_stats, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import ( + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, +) from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'BatchPredictionJob', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"BatchPredictionJob",}, ) @@ -153,6 +154,7 @@ class BatchPredictionJob(proto.Message): See https://goo.gl/xmQnxf for more information and examples of labels. """ + class InputConfig(proto.Message): r"""Configures the input to ``BatchPredictionJob``. @@ -179,12 +181,12 @@ class InputConfig(proto.Message): ``supported_input_storage_formats``. """ - gcs_source = proto.Field(proto.MESSAGE, number=2, oneof='source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, ) - bigquery_source = proto.Field(proto.MESSAGE, number=3, oneof='source', - message=io.BigQuerySource, + bigquery_source = proto.Field( + proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, ) instances_format = proto.Field(proto.STRING, number=1) @@ -255,11 +257,14 @@ class OutputConfig(proto.Message): ``supported_output_storage_formats``. 
""" - gcs_destination = proto.Field(proto.MESSAGE, number=2, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, ) - bigquery_destination = proto.Field(proto.MESSAGE, number=3, oneof='destination', + bigquery_destination = proto.Field( + proto.MESSAGE, + number=3, + oneof="destination", message=io.BigQueryDestination, ) @@ -280,9 +285,13 @@ class OutputInfo(proto.Message): prediction output is written. """ - gcs_output_directory = proto.Field(proto.STRING, number=1, oneof='output_location') + gcs_output_directory = proto.Field( + proto.STRING, number=1, oneof="output_location" + ) - bigquery_output_dataset = proto.Field(proto.STRING, number=2, oneof='output_location') + bigquery_output_dataset = proto.Field( + proto.STRING, number=2, oneof="output_location" + ) name = proto.Field(proto.STRING, number=1) @@ -290,67 +299,49 @@ class OutputInfo(proto.Message): model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, - message=InputConfig, - ) + input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) - model_parameters = proto.Field(proto.MESSAGE, number=5, - message=struct.Value, - ) + model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - output_config = proto.Field(proto.MESSAGE, number=6, - message=OutputConfig, - ) + output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) - dedicated_resources = proto.Field(proto.MESSAGE, number=7, - message=machine_resources.BatchDedicatedResources, + dedicated_resources = proto.Field( + proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, ) - manual_batch_tuning_parameters = proto.Field(proto.MESSAGE, number=8, + manual_batch_tuning_parameters = proto.Field( + proto.MESSAGE, + number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) generate_explanation = proto.Field(proto.BOOL, number=23) - output_info = proto.Field(proto.MESSAGE, number=9, - message=OutputInfo, - ) + output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) - state = proto.Field(proto.ENUM, number=10, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - error = proto.Field(proto.MESSAGE, number=11, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) - partial_failures = proto.RepeatedField(proto.MESSAGE, number=12, - message=status.Status, + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=12, message=status.Status, ) - resources_consumed = proto.Field(proto.MESSAGE, number=13, - message=machine_resources.ResourcesConsumed, + resources_consumed = proto.Field( + proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, ) - completion_stats = proto.Field(proto.MESSAGE, number=14, - message=gca_completion_stats.CompletionStats, + completion_stats = proto.Field( + proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=16, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=17, - message=timestamp.Timestamp, - ) + end_time = 
proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=18, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=19) diff --git a/google/cloud/aiplatform_v1beta1/types/completion_stats.py b/google/cloud/aiplatform_v1beta1/types/completion_stats.py index 3874f412df..165be59634 100644 --- a/google/cloud/aiplatform_v1beta1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/completion_stats.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'CompletionStats', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"CompletionStats",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 7ab803bec1..c8147f9d70 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -27,14 +27,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CustomJob', - 'CustomJobSpec', - 'WorkerPoolSpec', - 'ContainerSpec', - 'PythonPackageSpec', - 'Scheduling', + "CustomJob", + "CustomJobSpec", + "WorkerPoolSpec", + "ContainerSpec", + "PythonPackageSpec", + "Scheduling", }, ) @@ -89,33 +89,19 @@ class CustomJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, - message='CustomJobSpec', - ) + job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) - state = proto.Field(proto.ENUM, number=5, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=10, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) @@ -162,16 +148,14 @@ class CustomJobSpec(proto.Message): ``//logs/`` """ - worker_pool_specs = proto.RepeatedField(proto.MESSAGE, number=1, - message='WorkerPoolSpec', + worker_pool_specs = proto.RepeatedField( + proto.MESSAGE, number=1, message="WorkerPoolSpec", ) - scheduling = proto.Field(proto.MESSAGE, number=3, - message='Scheduling', - ) + scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) - base_output_directory = proto.Field(proto.MESSAGE, number=6, - message=io.GcsDestination, + base_output_directory = proto.Field( + proto.MESSAGE, number=6, message=io.GcsDestination, ) @@ -191,16 +175,16 @@ class WorkerPoolSpec(proto.Message): use for this worker pool. 
""" - container_spec = proto.Field(proto.MESSAGE, number=6, oneof='task', - message='ContainerSpec', + container_spec = proto.Field( + proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", ) - python_package_spec = proto.Field(proto.MESSAGE, number=7, oneof='task', - message='PythonPackageSpec', + python_package_spec = proto.Field( + proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", ) - machine_spec = proto.Field(proto.MESSAGE, number=1, - message=machine_resources.MachineSpec, + machine_spec = proto.Field( + proto.MESSAGE, number=1, message=machine_resources.MachineSpec, ) replica_count = proto.Field(proto.INT64, number=2) @@ -278,9 +262,7 @@ class Scheduling(proto.Message): to workers leaving and joining a job. """ - timeout = proto.Field(proto.MESSAGE, number=1, - message=duration.Duration, - ) + timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index 961a153172..e43a944d94 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'DataItem', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"DataItem",}, ) @@ -73,19 +70,13 @@ class DataItem(proto.Message): name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index 2d10060738..d94efba1b0 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -25,12 +25,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'DataLabelingJob', - 'ActiveLearningConfig', - 'SampleConfig', - 'TrainingConfig', + "DataLabelingJob", + "ActiveLearningConfig", + "SampleConfig", + "TrainingConfig", }, ) @@ -141,34 +141,24 @@ class DataLabelingJob(proto.Message): inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, - message=struct.Value, - ) + inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) - state = proto.Field(proto.ENUM, number=8, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, - message=money.Money, - ) + current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) - create_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - 
update_time = proto.Field(proto.MESSAGE, number=10, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=11) specialist_pools = proto.RepeatedField(proto.STRING, number=16) - active_learning_config = proto.Field(proto.MESSAGE, number=21, - message='ActiveLearningConfig', + active_learning_config = proto.Field( + proto.MESSAGE, number=21, message="ActiveLearningConfig", ) @@ -197,18 +187,18 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field(proto.INT64, number=1, oneof='human_labeling_budget') - - max_data_item_percentage = proto.Field(proto.INT32, number=2, oneof='human_labeling_budget') - - sample_config = proto.Field(proto.MESSAGE, number=3, - message='SampleConfig', + max_data_item_count = proto.Field( + proto.INT64, number=1, oneof="human_labeling_budget" ) - training_config = proto.Field(proto.MESSAGE, number=4, - message='TrainingConfig', + max_data_item_percentage = proto.Field( + proto.INT32, number=2, oneof="human_labeling_budget" ) + sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + + training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + class SampleConfig(proto.Message): r"""Active learning data sampling config. For every active @@ -228,6 +218,7 @@ class SampleConfig(proto.Message): strategy will decide which data should be selected for human labeling in every batch. """ + class SampleStrategy(proto.Enum): r"""Sample strategy decides which subset of DataItems should be selected for human labeling in every batch. @@ -235,14 +226,16 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field(proto.INT32, number=1, oneof='initial_batch_sample_size') - - following_batch_sample_percentage = proto.Field(proto.INT32, number=3, oneof='following_batch_sample_size') + initial_batch_sample_percentage = proto.Field( + proto.INT32, number=1, oneof="initial_batch_sample_size" + ) - sample_strategy = proto.Field(proto.ENUM, number=5, - enum=SampleStrategy, + following_batch_sample_percentage = proto.Field( + proto.INT32, number=3, oneof="following_batch_sample_size" ) + sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + class TrainingConfig(proto.Message): r"""CMLE training config. 
For every active learning labeling diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 5138badf1f..5840df17f3 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -24,12 +24,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Dataset', - 'ImportDataConfig', - 'ExportDataConfig', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, ) @@ -92,17 +88,11 @@ class Dataset(proto.Message): metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) etag = proto.Field(proto.STRING, number=6) @@ -141,8 +131,8 @@ class ImportDataConfig(proto.Message): Object `__. """ - gcs_source = proto.Field(proto.MESSAGE, number=1, oneof='source', - message=io.GcsSource, + gcs_source = proto.Field( + proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, ) data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) @@ -174,8 +164,8 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, ) annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 594484375c..7160b7b52f 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -26,26 +26,26 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateDatasetRequest', - 'CreateDatasetOperationMetadata', - 'GetDatasetRequest', - 'UpdateDatasetRequest', - 'ListDatasetsRequest', - 'ListDatasetsResponse', - 'DeleteDatasetRequest', - 'ImportDataRequest', - 'ImportDataResponse', - 'ImportDataOperationMetadata', - 'ExportDataRequest', - 'ExportDataResponse', - 'ExportDataOperationMetadata', - 'ListDataItemsRequest', - 'ListDataItemsResponse', - 'GetAnnotationSpecRequest', - 'ListAnnotationsRequest', - 'ListAnnotationsResponse', + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", }, ) @@ -65,9 +65,7 @@ class CreateDatasetRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, - message=gca_dataset.Dataset, - ) + dataset = 
proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) class CreateDatasetOperationMetadata(proto.Message): @@ -79,8 +77,8 @@ class CreateDatasetOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -97,9 +95,7 @@ class GetDatasetRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateDatasetRequest(proto.Message): @@ -122,13 +118,9 @@ class UpdateDatasetRequest(proto.Message): - ``labels`` """ - dataset = proto.Field(proto.MESSAGE, number=1, - message=gca_dataset.Dataset, - ) + dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class ListDatasetsRequest(proto.Message): @@ -165,9 +157,7 @@ class ListDatasetsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -188,8 +178,8 @@ class ListDatasetsResponse(proto.Message): def raw_page(self): return self - datasets = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_dataset.Dataset, + datasets = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_dataset.Dataset, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -225,8 +215,8 @@ class ImportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField(proto.MESSAGE, number=2, - message=gca_dataset.ImportDataConfig, + import_configs = proto.RepeatedField( + proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, ) @@ -245,8 +235,8 @@ class ImportDataOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -264,8 +254,8 @@ class ExportDataRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - export_config = proto.Field(proto.MESSAGE, number=2, - message=gca_dataset.ExportDataConfig, + export_config = proto.Field( + proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, ) @@ -295,8 +285,8 @@ class ExportDataOperationMetadata(proto.Message): the directory. 
""" - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -333,9 +323,7 @@ class ListDataItemsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -356,8 +344,8 @@ class ListDataItemsResponse(proto.Message): def raw_page(self): return self - data_items = proto.RepeatedField(proto.MESSAGE, number=1, - message=data_item.DataItem, + data_items = proto.RepeatedField( + proto.MESSAGE, number=1, message=data_item.DataItem, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -378,9 +366,7 @@ class GetAnnotationSpecRequest(proto.Message): name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class ListAnnotationsRequest(proto.Message): @@ -415,9 +401,7 @@ class ListAnnotationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -438,8 +422,8 @@ class ListAnnotationsResponse(proto.Message): def raw_page(self): return self - annotations = proto.RepeatedField(proto.MESSAGE, number=1, - message=annotation.Annotation, + annotations = proto.RepeatedField( + proto.MESSAGE, number=1, message=annotation.Annotation, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py index aa5c8424aa..b0ec7010a2 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'DeployedModelRef', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"DeployedModelRef",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 7d1275703d..07f6a2c61b 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -24,11 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Endpoint', - 'DeployedModel', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"Endpoint", "DeployedModel",}, ) @@ -91,8 +87,8 @@ class Endpoint(proto.Message): description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField(proto.MESSAGE, number=4, - message='DeployedModel', + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=4, message="DeployedModel", ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) @@ -101,13 +97,9 @@ class Endpoint(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=8, 
message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) class DeployedModel(proto.Message): @@ -174,11 +166,17 @@ class DeployedModel(proto.Message): option. """ - dedicated_resources = proto.Field(proto.MESSAGE, number=7, oneof='prediction_resources', + dedicated_resources = proto.Field( + proto.MESSAGE, + number=7, + oneof="prediction_resources", message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field(proto.MESSAGE, number=8, oneof='prediction_resources', + automatic_resources = proto.Field( + proto.MESSAGE, + number=8, + oneof="prediction_resources", message=machine_resources.AutomaticResources, ) @@ -188,12 +186,10 @@ class DeployedModel(proto.Message): display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - explanation_spec = proto.Field(proto.MESSAGE, number=9, - message=explanation.ExplanationSpec, + explanation_spec = proto.Field( + proto.MESSAGE, number=9, message=explanation.ExplanationSpec, ) enable_container_logging = proto.Field(proto.BOOL, number=12) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index acbf58d123..4bc9f35594 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -24,21 +24,21 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateEndpointRequest', - 'CreateEndpointOperationMetadata', - 'GetEndpointRequest', - 'ListEndpointsRequest', - 'ListEndpointsResponse', - 'UpdateEndpointRequest', - 'DeleteEndpointRequest', - 'DeployModelRequest', - 'DeployModelResponse', - 'DeployModelOperationMetadata', - 'UndeployModelRequest', - 'UndeployModelResponse', - 'UndeployModelOperationMetadata', + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + "DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", }, ) @@ -58,9 +58,7 @@ class CreateEndpointRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, - message=gca_endpoint.Endpoint, - ) + endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) class CreateEndpointOperationMetadata(proto.Message): @@ -72,8 +70,8 @@ class CreateEndpointOperationMetadata(proto.Message): The operation generic information. 
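
A minimal sketch of the CreateEndpointRequest shape after the reformat, using placeholder resource names:

    from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint
    from google.cloud.aiplatform_v1beta1.types import endpoint_service

    # CreateEndpointRequest pairs the parent location with the Endpoint to create.
    request = endpoint_service.CreateEndpointRequest(
        parent="projects/my-project/locations/us-central1",
        endpoint=gca_endpoint.Endpoint(display_name="my-endpoint"),
    )
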
""" - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -145,9 +143,7 @@ class ListEndpointsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListEndpointsResponse(proto.Message): @@ -167,8 +163,8 @@ class ListEndpointsResponse(proto.Message): def raw_page(self): return self - endpoints = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_endpoint.Endpoint, + endpoints = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -187,13 +183,9 @@ class UpdateEndpointRequest(proto.Message): resource. """ - endpoint = proto.Field(proto.MESSAGE, number=1, - message=gca_endpoint.Endpoint, - ) + endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteEndpointRequest(proto.Message): @@ -246,8 +238,8 @@ class DeployModelRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field(proto.MESSAGE, number=2, - message=gca_endpoint.DeployedModel, + deployed_model = proto.Field( + proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, ) traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -263,8 +255,8 @@ class DeployModelResponse(proto.Message): the Endpoint. """ - deployed_model = proto.Field(proto.MESSAGE, number=1, - message=gca_endpoint.DeployedModel, + deployed_model = proto.Field( + proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel, ) @@ -277,8 +269,8 @@ class DeployModelOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -327,8 +319,8 @@ class UndeployModelOperationMetadata(proto.Message): The operation generic information. 
""" - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 3eb6531af1..207e8275cd 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'EnvVar', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"EnvVar",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 41778b055c..06b930d90c 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -23,14 +23,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'Explanation', - 'ModelExplanation', - 'Attribution', - 'ExplanationSpec', - 'ExplanationParameters', - 'SampledShapleyAttribution', + "Explanation", + "ModelExplanation", + "Attribution", + "ExplanationSpec", + "ExplanationParameters", + "SampledShapleyAttribution", }, ) @@ -59,9 +59,7 @@ class Explanation(proto.Message): explaining. """ - attributions = proto.RepeatedField(proto.MESSAGE, number=1, - message='Attribution', - ) + attributions = proto.RepeatedField(proto.MESSAGE, number=1, message="Attribution",) class ModelExplanation(proto.Message): @@ -100,8 +98,8 @@ class ModelExplanation(proto.Message): is not populated. """ - mean_attributions = proto.RepeatedField(proto.MESSAGE, number=1, - message='Attribution', + mean_attributions = proto.RepeatedField( + proto.MESSAGE, number=1, message="Attribution", ) @@ -209,9 +207,7 @@ class Attribution(proto.Message): instance_output_value = proto.Field(proto.DOUBLE, number=2) - feature_attributions = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + feature_attributions = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) output_index = proto.RepeatedField(proto.INT32, number=4) @@ -233,12 +229,10 @@ class ExplanationSpec(proto.Message): input and output for explanation. """ - parameters = proto.Field(proto.MESSAGE, number=1, - message='ExplanationParameters', - ) + parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) - metadata = proto.Field(proto.MESSAGE, number=2, - message=explanation_metadata.ExplanationMetadata, + metadata = proto.Field( + proto.MESSAGE, number=2, message=explanation_metadata.ExplanationMetadata, ) @@ -254,8 +248,8 @@ class ExplanationParameters(proto.Message): considering all subsets of features. 
""" - sampled_shapley_attribution = proto.Field(proto.MESSAGE, number=1, - message='SampledShapleyAttribution', + sampled_shapley_attribution = proto.Field( + proto.MESSAGE, number=1, message="SampledShapleyAttribution", ) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index a6fed18554..cc60c125be 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ExplanationMetadata', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"ExplanationMetadata",}, ) @@ -61,6 +58,7 @@ class ExplanationMetadata(proto.Message): output URI will point to a location where the user only has a read access. """ + class InputMetadata(proto.Message): r"""Metadata of the input of a feature. @@ -83,8 +81,8 @@ class InputMetadata(proto.Message): ``instance_schema_uri``. """ - input_baselines = proto.RepeatedField(proto.MESSAGE, number=1, - message=struct.Value, + input_baselines = proto.RepeatedField( + proto.MESSAGE, number=1, message=struct.Value, ) class OutputMetadata(proto.Message): @@ -120,18 +118,20 @@ class OutputMetadata(proto.Message): for a specific output. """ - index_display_name_mapping = proto.Field(proto.MESSAGE, number=1, oneof='display_name_mapping', - message=struct.Value, + index_display_name_mapping = proto.Field( + proto.MESSAGE, number=1, oneof="display_name_mapping", message=struct.Value, ) - display_name_mapping_key = proto.Field(proto.STRING, number=2, oneof='display_name_mapping') + display_name_mapping_key = proto.Field( + proto.STRING, number=2, oneof="display_name_mapping" + ) - inputs = proto.MapField(proto.STRING, proto.MESSAGE, number=1, - message=InputMetadata, + inputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, ) - outputs = proto.MapField(proto.STRING, proto.MESSAGE, number=2, - message=OutputMetadata, + outputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=2, message=OutputMetadata, ) feature_attributions_schema_uri = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index e421cbe615..78af635e79 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -26,10 +26,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'HyperparameterTuningJob', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"HyperparameterTuningJob",}, ) @@ -102,9 +99,7 @@ class HyperparameterTuningJob(proto.Message): display_name = proto.Field(proto.STRING, number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, - message=study.StudySpec, - ) + study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) max_trial_count = proto.Field(proto.INT32, number=5) @@ -112,37 +107,23 @@ class HyperparameterTuningJob(proto.Message): max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field(proto.MESSAGE, number=8, - message=custom_job.CustomJobSpec, + trial_job_spec = proto.Field( + proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, - message=study.Trial, - ) + trials = 
proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) - state = proto.Field(proto.ENUM, number=10, - enum=job_state.JobState, - ) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=15, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) labels = proto.MapField(proto.STRING, proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index 7e47f3e3f7..f5fcc170f9 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -19,13 +19,13 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'GcsSource', - 'GcsDestination', - 'BigQuerySource', - 'BigQueryDestination', - 'ContainerRegistryDestination', + "GcsSource", + "GcsDestination", + "BigQuerySource", + "BigQueryDestination", + "ContainerRegistryDestination", }, ) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index 45c303431a..f64f07cbe3 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -18,40 +18,46 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateCustomJobRequest', - 'GetCustomJobRequest', - 'ListCustomJobsRequest', - 'ListCustomJobsResponse', - 'DeleteCustomJobRequest', - 'CancelCustomJobRequest', - 'CreateDataLabelingJobRequest', - 'GetDataLabelingJobRequest', - 'ListDataLabelingJobsRequest', - 'ListDataLabelingJobsResponse', - 'DeleteDataLabelingJobRequest', - 'CancelDataLabelingJobRequest', - 'CreateHyperparameterTuningJobRequest', - 'GetHyperparameterTuningJobRequest', - 'ListHyperparameterTuningJobsRequest', - 'ListHyperparameterTuningJobsResponse', - 'DeleteHyperparameterTuningJobRequest', - 
'CancelHyperparameterTuningJobRequest', - 'CreateBatchPredictionJobRequest', - 'GetBatchPredictionJobRequest', - 'ListBatchPredictionJobsRequest', - 'ListBatchPredictionJobsResponse', - 'DeleteBatchPredictionJobRequest', - 'CancelBatchPredictionJobRequest', + "CreateCustomJobRequest", + "GetCustomJobRequest", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "DeleteCustomJobRequest", + "CancelCustomJobRequest", + "CreateDataLabelingJobRequest", + "GetDataLabelingJobRequest", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "DeleteDataLabelingJobRequest", + "CancelDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "GetHyperparameterTuningJobRequest", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "DeleteHyperparameterTuningJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "GetBatchPredictionJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "DeleteBatchPredictionJobRequest", + "CancelBatchPredictionJobRequest", }, ) @@ -71,9 +77,7 @@ class CreateCustomJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, - message=gca_custom_job.CustomJob, - ) + custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) class GetCustomJobRequest(proto.Message): @@ -136,9 +140,7 @@ class ListCustomJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListCustomJobsResponse(proto.Message): @@ -158,8 +160,8 @@ class ListCustomJobsResponse(proto.Message): def raw_page(self): return self - custom_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_custom_job.CustomJob, + custom_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -206,8 +208,8 @@ class CreateDataLabelingJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field(proto.MESSAGE, number=2, - message=gca_data_labeling_job.DataLabelingJob, + data_labeling_job = proto.Field( + proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, ) @@ -273,9 +275,7 @@ class ListDataLabelingJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) order_by = proto.Field(proto.STRING, number=6) @@ -296,8 +296,8 @@ class ListDataLabelingJobsResponse(proto.Message): def raw_page(self): return self - data_labeling_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_data_labeling_job.DataLabelingJob, + data_labeling_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -348,7 +348,9 @@ class CreateHyperparameterTuningJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field(proto.MESSAGE, number=2, + hyperparameter_tuning_job = proto.Field( + proto.MESSAGE, + number=2, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -415,9 +417,7 @@ class 
ListHyperparameterTuningJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListHyperparameterTuningJobsResponse(proto.Message): @@ -439,7 +439,9 @@ class ListHyperparameterTuningJobsResponse(proto.Message): def raw_page(self): return self - hyperparameter_tuning_jobs = proto.RepeatedField(proto.MESSAGE, number=1, + hyperparameter_tuning_jobs = proto.RepeatedField( + proto.MESSAGE, + number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) @@ -491,8 +493,8 @@ class CreateBatchPredictionJobRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field(proto.MESSAGE, number=2, - message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_job = proto.Field( + proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -558,9 +560,7 @@ class ListBatchPredictionJobsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListBatchPredictionJobsResponse(proto.Message): @@ -581,8 +581,8 @@ class ListBatchPredictionJobsResponse(proto.Message): def raw_page(self): return self - batch_prediction_jobs = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_batch_prediction_job.BatchPredictionJob, + batch_prediction_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/job_state.py b/google/cloud/aiplatform_v1beta1/types/job_state.py index f23f7f60cd..f86e179b1b 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_state.py +++ b/google/cloud/aiplatform_v1beta1/types/job_state.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'JobState', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"JobState",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 88aea166a2..f713cd2f64 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -18,17 +18,19 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import accelerator_type as gca_accelerator_type +from google.cloud.aiplatform_v1beta1.types import ( + accelerator_type as gca_accelerator_type, +) __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'MachineSpec', - 'DedicatedResources', - 'AutomaticResources', - 'BatchDedicatedResources', - 'ResourcesConsumed', + "MachineSpec", + "DedicatedResources", + "AutomaticResources", + "BatchDedicatedResources", + "ResourcesConsumed", }, ) @@ -88,8 +90,8 @@ class MachineSpec(proto.Message): machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field(proto.ENUM, number=2, - enum=gca_accelerator_type.AcceleratorType, + accelerator_type = proto.Field( + proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, ) accelerator_count = proto.Field(proto.INT32, number=3) @@ -128,9 +130,7 @@ class 
DedicatedResources(proto.Message): as the default value. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, - message=MachineSpec, - ) + machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) min_replica_count = proto.Field(proto.INT32, number=2) @@ -194,9 +194,7 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, - message=MachineSpec, - ) + machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) starting_replica_count = proto.Field(proto.INT32, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py index da5c4d38ab..7a467d5069 100644 --- a/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py +++ b/google/cloud/aiplatform_v1beta1/types/manual_batch_tuning_parameters.py @@ -19,10 +19,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ManualBatchTuningParameters', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"ManualBatchTuningParameters",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py index a96f6d420f..99a6e65a42 100644 --- a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -22,10 +22,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'MigratableResource', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"MigratableResource",}, ) @@ -55,6 +52,7 @@ class MigratableResource(proto.Message): Output only. Timestamp when this MigratableResource was last updated. """ + class MlEngineModelVersion(proto.Message): r"""Represents one model Version in ml.googleapis.com. @@ -123,6 +121,7 @@ class DataLabelingDataset(proto.Message): datalabeling.googleapis.com belongs to the data labeling Dataset. """ + class DataLabelingAnnotatedDataset(proto.Message): r"""Represents one AnnotatedDataset in datalabeling.googleapis.com. 
@@ -146,32 +145,34 @@ class DataLabelingAnnotatedDataset(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=4) - data_labeling_annotated_datasets = proto.RepeatedField(proto.MESSAGE, number=3, - message='MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset', + data_labeling_annotated_datasets = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", ) - ml_engine_model_version = proto.Field(proto.MESSAGE, number=1, oneof='resource', - message=MlEngineModelVersion, + ml_engine_model_version = proto.Field( + proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, ) - automl_model = proto.Field(proto.MESSAGE, number=2, oneof='resource', - message=AutomlModel, + automl_model = proto.Field( + proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, ) - automl_dataset = proto.Field(proto.MESSAGE, number=3, oneof='resource', - message=AutomlDataset, + automl_dataset = proto.Field( + proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, ) - data_labeling_dataset = proto.Field(proto.MESSAGE, number=4, oneof='resource', - message=DataLabelingDataset, + data_labeling_dataset = proto.Field( + proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, ) - last_migrate_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, + last_migrate_time = proto.Field( + proto.MESSAGE, number=5, message=timestamp.Timestamp, ) - last_update_time = proto.Field(proto.MESSAGE, number=6, - message=timestamp.Timestamp, + last_update_time = proto.Field( + proto.MESSAGE, number=6, message=timestamp.Timestamp, ) diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py index 607629f06a..46b0cdc66b 100644 --- a/google/cloud/aiplatform_v1beta1/types/migration_service.py +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -18,20 +18,22 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import migratable_resource as gca_migratable_resource +from google.cloud.aiplatform_v1beta1.types import ( + migratable_resource as gca_migratable_resource, +) from google.cloud.aiplatform_v1beta1.types import operation __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'SearchMigratableResourcesRequest', - 'SearchMigratableResourcesResponse', - 'BatchMigrateResourcesRequest', - 'MigrateResourceRequest', - 'BatchMigrateResourcesResponse', - 'MigrateResourceResponse', - 'BatchMigrateResourcesOperationMetadata', + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", }, ) @@ -79,8 +81,8 @@ class SearchMigratableResourcesResponse(proto.Message): def raw_page(self): return self - migratable_resources = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_migratable_resource.MigratableResource, + migratable_resources = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -103,8 +105,8 @@ class BatchMigrateResourcesRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - migrate_resource_requests = proto.RepeatedField(proto.MESSAGE, number=2, - 
message='MigrateResourceRequest', + migrate_resource_requests = proto.RepeatedField( + proto.MESSAGE, number=2, message="MigrateResourceRequest", ) @@ -128,6 +130,7 @@ class MigrateResourceRequest(proto.Message): datalabeling.googleapis.com to AI Platform's Dataset. """ + class MigrateMlEngineModelVersionConfig(proto.Message): r"""Config for migrating version in ml.googleapis.com to AI Platform's Model. @@ -215,6 +218,7 @@ class MigrateDataLabelingDatasetConfig(proto.Message): AnnotatedDatasets have to belong to the datalabeling Dataset. """ + class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): r"""Config for migrating AnnotatedDataset in datalabeling.googleapis.com to AI Platform's SavedQuery. @@ -233,23 +237,31 @@ class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): dataset_display_name = proto.Field(proto.STRING, number=2) - migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField(proto.MESSAGE, number=3, - message='MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig', + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", ) - migrate_ml_engine_model_version_config = proto.Field(proto.MESSAGE, number=1, oneof='request', + migrate_ml_engine_model_version_config = proto.Field( + proto.MESSAGE, + number=1, + oneof="request", message=MigrateMlEngineModelVersionConfig, ) - migrate_automl_model_config = proto.Field(proto.MESSAGE, number=2, oneof='request', - message=MigrateAutomlModelConfig, + migrate_automl_model_config = proto.Field( + proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, ) - migrate_automl_dataset_config = proto.Field(proto.MESSAGE, number=3, oneof='request', - message=MigrateAutomlDatasetConfig, + migrate_automl_dataset_config = proto.Field( + proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, ) - migrate_data_labeling_dataset_config = proto.Field(proto.MESSAGE, number=4, oneof='request', + migrate_data_labeling_dataset_config = proto.Field( + proto.MESSAGE, + number=4, + oneof="request", message=MigrateDataLabelingDatasetConfig, ) @@ -263,8 +275,8 @@ class BatchMigrateResourcesResponse(proto.Message): Successfully migrated resources. """ - migrate_resource_responses = proto.RepeatedField(proto.MESSAGE, number=1, - message='MigrateResourceResponse', + migrate_resource_responses = proto.RepeatedField( + proto.MESSAGE, number=1, message="MigrateResourceResponse", ) @@ -282,12 +294,12 @@ class MigrateResourceResponse(proto.Message): datalabeling.googleapis.com. """ - dataset = proto.Field(proto.STRING, number=1, oneof='migrated_resource') + dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") - model = proto.Field(proto.STRING, number=2, oneof='migrated_resource') + model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") - migratable_resource = proto.Field(proto.MESSAGE, number=3, - message=gca_migratable_resource.MigratableResource, + migratable_resource = proto.Field( + proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, ) @@ -300,8 +312,8 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): The common part of the operation metadata. 
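
A hedged sketch of a BatchMigrateResourcesRequest carrying a single AutoML model migration; the ``model`` field name on MigrateAutomlModelConfig is assumed from the config definition above, and the resource names are placeholders:

    from google.cloud.aiplatform_v1beta1.types import migration_service

    # Each MigrateResourceRequest sets exactly one arm of its "request" oneof.
    request = migration_service.BatchMigrateResourcesRequest(
        parent="projects/my-project/locations/us-central1",
        migrate_resource_requests=[
            migration_service.MigrateResourceRequest(
                migrate_automl_model_config=(
                    migration_service.MigrateResourceRequest.MigrateAutomlModelConfig(
                        model="projects/my-project/locations/us-central1/models/TBL123",
                    )
                ),
            ),
        ],
    )
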
""" - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index abd5f67e94..7fa9130909 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -26,13 +26,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Model', - 'PredictSchemata', - 'ModelContainerSpec', - 'Port', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, ) @@ -232,6 +227,7 @@ class Model(proto.Message): See https://goo.gl/xmQnxf for more information and examples of labels. """ + class DeploymentResourcesType(proto.Enum): r"""Identifies a type of Model's prediction resources.""" DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 @@ -268,6 +264,7 @@ class ExportFormat(proto.Message): Output only. The content of this Model that may be exported. """ + class ExportableContent(proto.Enum): r"""The Model content that can be exported.""" EXPORTABLE_CONTENT_UNSPECIFIED = 0 @@ -276,8 +273,8 @@ class ExportableContent(proto.Enum): id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField(proto.ENUM, number=2, - enum='Model.ExportFormat.ExportableContent', + exportable_contents = proto.RepeatedField( + proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", ) name = proto.Field(proto.STRING, number=1) @@ -286,50 +283,40 @@ class ExportableContent(proto.Enum): description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, - message='PredictSchemata', - ) + predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - supported_export_formats = proto.RepeatedField(proto.MESSAGE, number=20, - message=ExportFormat, + supported_export_formats = proto.RepeatedField( + proto.MESSAGE, number=20, message=ExportFormat, ) training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, - message='ModelContainerSpec', - ) + container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField(proto.ENUM, number=10, - enum=DeploymentResourcesType, + supported_deployment_resources_types = proto.RepeatedField( + proto.ENUM, number=10, enum=DeploymentResourcesType, ) supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - deployed_models = proto.RepeatedField(proto.MESSAGE, number=15, - message=deployed_model_ref.DeployedModelRef, + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=15, 
message=deployed_model_ref.DeployedModelRef, ) - explanation_spec = proto.Field(proto.MESSAGE, number=23, - message=explanation.ExplanationSpec, + explanation_spec = proto.Field( + proto.MESSAGE, number=23, message=explanation.ExplanationSpec, ) etag = proto.Field(proto.STRING, number=16) @@ -657,13 +644,9 @@ class ModelContainerSpec(proto.Message): args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, - message=env_var.EnvVar, - ) + env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) - ports = proto.RepeatedField(proto.MESSAGE, number=5, - message='Port', - ) + ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) predict_route = proto.Field(proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index 5613b3017d..b768ed978e 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -24,10 +24,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ModelEvaluation', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluation",}, ) @@ -74,18 +71,14 @@ class ModelEvaluation(proto.Message): metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) slice_dimensions = proto.RepeatedField(proto.STRING, number=5) - model_explanation = proto.Field(proto.MESSAGE, number=8, - message=explanation.ModelExplanation, + model_explanation = proto.Field( + proto.MESSAGE, number=8, message=explanation.ModelExplanation, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index 7d21157f1a..1039d32c1f 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -23,10 +23,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'ModelEvaluationSlice', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"ModelEvaluationSlice",}, ) @@ -57,6 +54,7 @@ class ModelEvaluationSlice(proto.Message): Output only. Timestamp when this ModelEvaluationSlice was created. """ + class Slice(proto.Message): r"""Definition of a slice. 
@@ -81,19 +79,13 @@ class Slice(proto.Message): name = proto.Field(proto.STRING, number=1) - slice_ = proto.Field(proto.MESSAGE, number=2, - message=Slice, - ) + slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=5, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index e5945e49c0..3cfb17ad2c 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -27,25 +27,25 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'UploadModelRequest', - 'UploadModelOperationMetadata', - 'UploadModelResponse', - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'UpdateModelRequest', - 'DeleteModelRequest', - 'ExportModelRequest', - 'ExportModelOperationMetadata', - 'ExportModelResponse', - 'GetModelEvaluationRequest', - 'ListModelEvaluationsRequest', - 'ListModelEvaluationsResponse', - 'GetModelEvaluationSliceRequest', - 'ListModelEvaluationSlicesRequest', - 'ListModelEvaluationSlicesResponse', + "UploadModelRequest", + "UploadModelOperationMetadata", + "UploadModelResponse", + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "DeleteModelRequest", + "ExportModelRequest", + "ExportModelOperationMetadata", + "ExportModelResponse", + "GetModelEvaluationRequest", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "GetModelEvaluationSliceRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", }, ) @@ -65,9 +65,7 @@ class UploadModelRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, - message=gca_model.Model, - ) + model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) class UploadModelOperationMetadata(proto.Message): @@ -80,8 +78,8 @@ class UploadModelOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -143,9 +141,7 @@ class ListModelsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelsResponse(proto.Message): @@ -165,9 +161,7 @@ class ListModelsResponse(proto.Message): def raw_page(self): return self - models = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_model.Model, - ) + models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) next_page_token = proto.Field(proto.STRING, number=2) @@ -187,13 +181,9 @@ class UpdateModelRequest(proto.Message): [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). 
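
A minimal sketch of the UpdateModelRequest/FieldMask pattern described in the docstring above, using placeholder resource names:

    from google.cloud.aiplatform_v1beta1.types import model as gca_model
    from google.cloud.aiplatform_v1beta1.types import model_service
    from google.protobuf import field_mask_pb2

    # Only the paths named in update_mask are applied; other Model fields
    # in the payload are ignored by the server.
    request = model_service.UpdateModelRequest(
        model=gca_model.Model(
            name="projects/my-project/locations/us-central1/models/123",
            display_name="renamed-model",
        ),
        update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
    )
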
""" - model = proto.Field(proto.MESSAGE, number=1, - message=gca_model.Model, - ) + model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class DeleteModelRequest(proto.Message): @@ -222,6 +212,7 @@ class ExportModelRequest(proto.Message): Required. The desired output location and configuration. """ + class OutputConfig(proto.Message): r"""Output configuration for the Model export. @@ -253,19 +244,17 @@ class OutputConfig(proto.Message): export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field(proto.MESSAGE, number=3, - message=io.GcsDestination, + artifact_destination = proto.Field( + proto.MESSAGE, number=3, message=io.GcsDestination, ) - image_destination = proto.Field(proto.MESSAGE, number=4, - message=io.ContainerRegistryDestination, + image_destination = proto.Field( + proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, - message=OutputConfig, - ) + output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) class ExportModelOperationMetadata(proto.Message): @@ -280,6 +269,7 @@ class ExportModelOperationMetadata(proto.Message): Output only. Information further describing the output of this Model export. """ + class OutputInfo(proto.Message): r"""Further describes the output of the ExportModel. Supplements ``ExportModelRequest.OutputConfig``. @@ -301,13 +291,11 @@ class OutputInfo(proto.Message): image_output_uri = proto.Field(proto.STRING, number=3) - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - output_info = proto.Field(proto.MESSAGE, number=2, - message=OutputInfo, - ) + output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) class ExportModelResponse(proto.Message): @@ -362,9 +350,7 @@ class ListModelEvaluationsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelEvaluationsResponse(proto.Message): @@ -385,8 +371,8 @@ class ListModelEvaluationsResponse(proto.Message): def raw_page(self): return self - model_evaluations = proto.RepeatedField(proto.MESSAGE, number=1, - message=model_evaluation.ModelEvaluation, + model_evaluations = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -441,9 +427,7 @@ class ListModelEvaluationSlicesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListModelEvaluationSlicesResponse(proto.Message): @@ -464,8 +448,8 @@ class ListModelEvaluationSlicesResponse(proto.Message): def raw_page(self): return self - model_evaluation_slices = proto.RepeatedField(proto.MESSAGE, number=1, - message=model_evaluation_slice.ModelEvaluationSlice, + model_evaluation_slices = proto.RepeatedField( + proto.MESSAGE, number=1, 
message=model_evaluation_slice.ModelEvaluationSlice, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index bf2a5906bd..12b2150c35 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -23,11 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'GenericOperationMetadata', - 'DeleteOperationMetadata', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, ) @@ -49,17 +46,13 @@ class GenericOperationMetadata(proto.Message): updated for the last time. """ - partial_failures = proto.RepeatedField(proto.MESSAGE, number=1, - message=status.Status, + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=1, message=status.Status, ) - create_time = proto.Field(proto.MESSAGE, number=2, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=3, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) class DeleteOperationMetadata(proto.Message): @@ -70,8 +63,8 @@ class DeleteOperationMetadata(proto.Message): The common part of the operation metadata. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index 727855e58a..9f0856732d 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -18,19 +18,21 @@ import proto # type: ignore -from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) from google.protobuf import field_mask_pb2 as field_mask # type: ignore __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateTrainingPipelineRequest', - 'GetTrainingPipelineRequest', - 'ListTrainingPipelinesRequest', - 'ListTrainingPipelinesResponse', - 'DeleteTrainingPipelineRequest', - 'CancelTrainingPipelineRequest', + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", }, ) @@ -50,8 +52,8 @@ class CreateTrainingPipelineRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field(proto.MESSAGE, number=2, - message=gca_training_pipeline.TrainingPipeline, + training_pipeline = proto.Field( + proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, ) @@ -114,9 +116,7 @@ class ListTrainingPipelinesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) class ListTrainingPipelinesResponse(proto.Message): @@ -137,8 +137,8 @@ class ListTrainingPipelinesResponse(proto.Message): def 
raw_page(self): return self - training_pipelines = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_training_pipeline.TrainingPipeline, + training_pipelines = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py index b04954f602..cede653bd6 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_state.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_state.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'PipelineState', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"PipelineState",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index 872990a5f1..8f8717d675 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -23,12 +23,12 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'PredictRequest', - 'PredictResponse', - 'ExplainRequest', - 'ExplainResponse', + "PredictRequest", + "PredictResponse", + "ExplainRequest", + "ExplainResponse", }, ) @@ -65,13 +65,9 @@ class PredictRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, - message=struct.Value, - ) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=3, - message=struct.Value, - ) + parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) class PredictResponse(proto.Message): @@ -91,9 +87,7 @@ class PredictResponse(proto.Message): served this prediction. """ - predictions = proto.RepeatedField(proto.MESSAGE, number=1, - message=struct.Value, - ) + predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) deployed_model_id = proto.Field(proto.STRING, number=2) @@ -134,13 +128,9 @@ class ExplainRequest(proto.Message): endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, - message=struct.Value, - ) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=4, - message=struct.Value, - ) + parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) deployed_model_id = proto.Field(proto.STRING, number=3) @@ -162,8 +152,8 @@ class ExplainResponse(proto.Message): served this explanation. 
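
A minimal sketch of building a PredictRequest; instances are schemaless ``google.protobuf.Value`` payloads, so ``json_format.ParseDict`` is a convenient way to produce them from plain dicts (resource names and feature keys are placeholders):

    from google.cloud.aiplatform_v1beta1.types import prediction_service
    from google.protobuf import json_format
    from google.protobuf import struct_pb2

    # Convert a plain dict into a google.protobuf.Value instance payload.
    instance = json_format.ParseDict(
        {"feature_a": 1.0, "feature_b": "x"}, struct_pb2.Value()
    )

    request = prediction_service.PredictRequest(
        endpoint="projects/my-project/locations/us-central1/endpoints/456",
        instances=[instance],
    )
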
""" - explanations = proto.RepeatedField(proto.MESSAGE, number=1, - message=explanation.Explanation, + explanations = proto.RepeatedField( + proto.MESSAGE, number=1, message=explanation.Explanation, ) deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py index f75416157b..4ac8c6a709 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'SpecialistPool', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"SpecialistPool",}, ) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 8ee901a444..724f7165a6 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -24,16 +24,16 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'CreateSpecialistPoolRequest', - 'CreateSpecialistPoolOperationMetadata', - 'GetSpecialistPoolRequest', - 'ListSpecialistPoolsRequest', - 'ListSpecialistPoolsResponse', - 'DeleteSpecialistPoolRequest', - 'UpdateSpecialistPoolRequest', - 'UpdateSpecialistPoolOperationMetadata', + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", }, ) @@ -53,8 +53,8 @@ class CreateSpecialistPoolRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field(proto.MESSAGE, number=2, - message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field( + proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, ) @@ -67,8 +67,8 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): The operation generic information. """ - generic_metadata = proto.Field(proto.MESSAGE, number=1, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) @@ -114,9 +114,7 @@ class ListSpecialistPoolsRequest(proto.Message): page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, - message=field_mask.FieldMask, - ) + read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) class ListSpecialistPoolsResponse(proto.Message): @@ -135,8 +133,8 @@ class ListSpecialistPoolsResponse(proto.Message): def raw_page(self): return self - specialist_pools = proto.RepeatedField(proto.MESSAGE, number=1, - message=gca_specialist_pool.SpecialistPool, + specialist_pools = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) next_page_token = proto.Field(proto.STRING, number=2) @@ -176,13 +174,11 @@ class UpdateSpecialistPoolRequest(proto.Message): resource. 
""" - specialist_pool = proto.Field(proto.MESSAGE, number=1, - message=gca_specialist_pool.SpecialistPool, + specialist_pool = proto.Field( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) - update_mask = proto.Field(proto.MESSAGE, number=2, - message=field_mask.FieldMask, - ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) class UpdateSpecialistPoolOperationMetadata(proto.Message): @@ -201,8 +197,8 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field(proto.MESSAGE, number=2, - message=operation.GenericOperationMetadata, + generic_metadata = proto.Field( + proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index b60e344617..5d053e7162 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -23,12 +23,8 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'Trial', - 'StudySpec', - 'Measurement', - }, + package="google.cloud.aiplatform.v1beta1", + manifest={"Trial", "StudySpec", "Measurement",}, ) @@ -58,6 +54,7 @@ class Trial(proto.Message): Trial. It's set for a HyperparameterTuningJob's Trial. """ + class State(proto.Enum): r"""Describes a Trial state.""" STATE_UNSPECIFIED = 0 @@ -85,31 +82,19 @@ class Parameter(proto.Message): parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, - message=struct.Value, - ) + value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, - enum=State, - ) + state = proto.Field(proto.ENUM, number=3, enum=State,) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, - message=Parameter, - ) + parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) - final_measurement = proto.Field(proto.MESSAGE, number=5, - message='Measurement', - ) + final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) - start_time = proto.Field(proto.MESSAGE, number=7, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) custom_job = proto.Field(proto.STRING, number=11) @@ -125,6 +110,7 @@ class StudySpec(proto.Message): algorithm (~.study.StudySpec.Algorithm): The search algorithm specified for the Study. """ + class Algorithm(proto.Enum): r"""The available search algorithms for the Study.""" ALGORITHM_UNSPECIFIED = 0 @@ -143,6 +129,7 @@ class MetricSpec(proto.Message): Required. The optimization goal of the metric. """ + class GoalType(proto.Enum): r"""The available types of optimization goals.""" GOAL_TYPE_UNSPECIFIED = 0 @@ -151,9 +138,7 @@ class GoalType(proto.Enum): metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, - enum='StudySpec.MetricSpec.GoalType', - ) + goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. @@ -175,6 +160,7 @@ class ParameterSpec(proto.Message): How the parameter should be scaled. 
Leave unset for ``CATEGORICAL`` parameters. """ + class ScaleType(proto.Enum): r"""The type of scaling that should be applied to this parameter.""" SCALE_TYPE_UNSPECIFIED = 0 @@ -239,39 +225,45 @@ class DiscreteValueSpec(proto.Message): values = proto.RepeatedField(proto.DOUBLE, number=1) - double_value_spec = proto.Field(proto.MESSAGE, number=2, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.DoubleValueSpec', + double_value_spec = proto.Field( + proto.MESSAGE, + number=2, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DoubleValueSpec", ) - integer_value_spec = proto.Field(proto.MESSAGE, number=3, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.IntegerValueSpec', + integer_value_spec = proto.Field( + proto.MESSAGE, + number=3, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.IntegerValueSpec", ) - categorical_value_spec = proto.Field(proto.MESSAGE, number=4, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.CategoricalValueSpec', + categorical_value_spec = proto.Field( + proto.MESSAGE, + number=4, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.CategoricalValueSpec", ) - discrete_value_spec = proto.Field(proto.MESSAGE, number=5, oneof='parameter_value_spec', - message='StudySpec.ParameterSpec.DiscreteValueSpec', + discrete_value_spec = proto.Field( + proto.MESSAGE, + number=5, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DiscreteValueSpec", ) parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field(proto.ENUM, number=6, - enum='StudySpec.ParameterSpec.ScaleType', + scale_type = proto.Field( + proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", ) - metrics = proto.RepeatedField(proto.MESSAGE, number=1, - message=MetricSpec, - ) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, - message=ParameterSpec, - ) + parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) - algorithm = proto.Field(proto.ENUM, number=3, - enum=Algorithm, - ) + algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) class Measurement(proto.Message): @@ -289,6 +281,7 @@ class Measurement(proto.Message): evaluating the objective functions using suggested Parameter values. """ + class Metric(proto.Message): r"""A message representing a metric in the measurement. 
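The ParameterSpec hunks above fold each member of the ``parameter_value_spec`` oneof into a single proto.Field(...) call. The oneof semantics are unchanged by the reformat: assigning one member clears the others. A minimal sketch, assuming the min_value/max_value field names these specs expose in this API (values are illustrative):

    from google.cloud.aiplatform_v1beta1.types import study

    spec = study.StudySpec.ParameterSpec(parameter_id="lr")
    spec.double_value_spec = study.StudySpec.ParameterSpec.DoubleValueSpec(
        min_value=1e-4, max_value=1e-1
    )
    # Assigning a different member of the oneof replaces the previous one.
    spec.integer_value_spec = study.StudySpec.ParameterSpec.IntegerValueSpec(
        min_value=1, max_value=10
    )
    assert "integer_value_spec" in spec  # double_value_spec was cleared

The next hunk continues with Measurement.Metric.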
@@ -307,9 +300,7 @@ class Metric(proto.Message): step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, - message=Metric, - ) + metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 0729605971..86d6168b8e 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -27,14 +27,14 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', + package="google.cloud.aiplatform.v1beta1", manifest={ - 'TrainingPipeline', - 'InputDataConfig', - 'FractionSplit', - 'FilterSplit', - 'PredefinedSplit', - 'TimestampSplit', + "TrainingPipeline", + "InputDataConfig", + "FractionSplit", + "FilterSplit", + "PredefinedSplit", + "TimestampSplit", }, ) @@ -146,47 +146,27 @@ class TrainingPipeline(proto.Message): display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, - message='InputDataConfig', - ) + input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, - message=struct.Value, - ) + training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, - message=struct.Value, - ) + training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - model_to_upload = proto.Field(proto.MESSAGE, number=7, - message=model.Model, - ) + model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) - state = proto.Field(proto.ENUM, number=9, - enum=pipeline_state.PipelineState, - ) + state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) - error = proto.Field(proto.MESSAGE, number=10, - message=status.Status, - ) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) - create_time = proto.Field(proto.MESSAGE, number=11, - message=timestamp.Timestamp, - ) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, - message=timestamp.Timestamp, - ) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, - message=timestamp.Timestamp, - ) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, - message=timestamp.Timestamp, - ) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) labels = proto.MapField(proto.STRING, proto.STRING, number=15) @@ -284,24 +264,24 @@ class InputDataConfig(proto.Message): ``annotation_schema_uri``. 
""" - fraction_split = proto.Field(proto.MESSAGE, number=2, oneof='split', - message='FractionSplit', + fraction_split = proto.Field( + proto.MESSAGE, number=2, oneof="split", message="FractionSplit", ) - filter_split = proto.Field(proto.MESSAGE, number=3, oneof='split', - message='FilterSplit', + filter_split = proto.Field( + proto.MESSAGE, number=3, oneof="split", message="FilterSplit", ) - predefined_split = proto.Field(proto.MESSAGE, number=4, oneof='split', - message='PredefinedSplit', + predefined_split = proto.Field( + proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", ) - timestamp_split = proto.Field(proto.MESSAGE, number=5, oneof='split', - message='TimestampSplit', + timestamp_split = proto.Field( + proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", ) - gcs_destination = proto.Field(proto.MESSAGE, number=8, oneof='destination', - message=io.GcsDestination, + gcs_destination = proto.Field( + proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, ) dataset_id = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index 6e54a37598..710e4a6d16 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -19,10 +19,7 @@ __protobuf__ = proto.module( - package='google.cloud.aiplatform.v1beta1', - manifest={ - 'UserActionReference', - }, + package="google.cloud.aiplatform.v1beta1", manifest={"UserActionReference",}, ) @@ -47,9 +44,9 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". """ - operation = proto.Field(proto.STRING, number=1, oneof='reference') + operation = proto.Field(proto.STRING, number=1, oneof="reference") - data_labeling_job = proto.Field(proto.STRING, number=2, oneof='reference') + data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") method = proto.Field(proto.STRING, number=3) diff --git a/synth.metadata b/synth.metadata index a999634172..ec41255fd8 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,7 +4,7 @@ "git": { "name": ".", "remote": "https://github.com/dizcology/python-aiplatform.git", - "sha": "0fc50abf7571c93a7810506e73c05a72b9f6efc0" + "sha": "288035dd0612b35204273d09a2b3dbbba9fe5e2c" } }, { diff --git a/synth.py b/synth.py index ce8c810d80..b370daadeb 100644 --- a/synth.py +++ b/synth.py @@ -40,7 +40,6 @@ s.move( library, excludes=[ - ".kokoro", "setup.py", "README.rst", "docs/index.rst", @@ -53,9 +52,6 @@ # Patch the library # ---------------------------------------------------------------------------- -# https://github.com/googleapis/gapic-generator-python/issues/336 -s.replace("**/client.py", " operation.from_gapic", " ga_operation.from_gapic") - s.replace( "**/client.py", "client_options: ClientOptions = ", @@ -69,6 +65,13 @@ "request.instances.extend(instances)", ) +# https://github.com/googleapis/gapic-generator-python/issues/672 +s.replace( + "google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py", + "request.traffic_split.extend\(traffic_split\)", + "request.traffic_split = traffic_split", +) + # post processing to fix the generated reference doc from synthtool import transforms as st import re @@ -125,7 +128,11 @@ templated_files = common.py_library(cov_level=99, microgenerator=True) s.move( - templated_files, excludes=[".coveragerc"] + templated_files, + excludes=[ + ".coveragerc", + 
".kokoro/samples/**" + ] ) # the microgenerator has a good coveragerc file # Don't treat docs warnings as errors diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 002b1afc4e..8b4313034b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceClient, +) from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.services.dataset_service import transports from google.cloud.aiplatform_v1beta1.types import annotation @@ -62,7 +66,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -73,17 +81,35 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [DatasetServiceClient, DatasetServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] +) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -91,7 +117,7 @@ def 
test_dataset_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_client_get_transport_class(): @@ -102,29 +128,44 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) -def test_dataset_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -140,7 +181,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -156,7 +197,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". 
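The cases on either side of this point pin down all three supported values of GOOGLE_API_USE_MTLS_ENDPOINT. The selection rule they exercise can be sketched as a small pure function (resolve_endpoint is a hypothetical helper written for illustration; the generated client inlines this logic):

    from google.auth.exceptions import MutualTLSChannelError

    def resolve_endpoint(use_mtls_env, default, mtls_default, have_client_cert):
        # "never": plain endpoint; "always": mTLS endpoint;
        # "auto" (the default): mTLS only when a client certificate is in play.
        if use_mtls_env == "never":
            return default
        if use_mtls_env == "always":
            return mtls_default
        if use_mtls_env == "auto":
            return mtls_default if have_client_cert else default
        raise MutualTLSChannelError("Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value")

The "always" case continues below.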
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -176,13 +217,15 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -195,26 +238,56 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_dataset_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -237,11 +310,21 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -251,7 +334,9 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -266,10 +351,17 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -284,16 +376,23 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -306,16 +405,24 @@ def test_dataset_service_client_client_options_scopes(client_class, transport_cl client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
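These client_options tests each pin one knob and assert that it reaches the transport constructor verbatim. From user code the same knobs are supplied at construction time; a brief sketch (the endpoint and quota project id are illustrative):

    from google.api_core import client_options
    from google.cloud.aiplatform_v1beta1.services.dataset_service import (
        DatasetServiceClient,
    )

    options = client_options.ClientOptions(
        api_endpoint="aiplatform.googleapis.com",
        quota_project_id="my-quota-project",
    )
    client = DatasetServiceClient(client_options=options)

The credentials-file case continues below.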
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -330,10 +437,12 @@ def test_dataset_service_client_client_options_credentials_file(client_class, tr def test_dataset_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -346,10 +455,11 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): +def test_create_dataset( + transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -357,11 +467,9 @@ def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.Cr request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_dataset(request) @@ -380,10 +488,9 @@ def test_create_dataset_from_dict(): @pytest.mark.asyncio -async def test_create_dataset_async(transport: str = 'grpc_asyncio'): +async def test_create_dataset_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -391,12 +498,10 @@ async def test_create_dataset_async(transport: str = 'grpc_asyncio'): request = dataset_service.CreateDatasetRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_dataset(request) @@ -412,20 +517,16 @@ async def test_create_dataset_async(transport: str = 'grpc_asyncio'): def test_create_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_dataset(request) @@ -436,28 +537,23 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_dataset(request) @@ -468,29 +564,21 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -498,47 +586,40 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") def test_create_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -546,31 +627,30 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
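The args[0].parent and args[0].dataset assertions above work because the flattened overload simply copies each keyword into a fresh request proto before dispatch. A hedged paraphrase of that step (not the generated implementation itself):

    from google.cloud.aiplatform_v1beta1.types import dataset_service

    def build_create_dataset_request(parent, dataset):
        # Mixing a non-empty request object with these keywords raises
        # ValueError, which is exactly what the next case asserts.
        request = dataset_service.CreateDatasetRequest()
        request.parent = parent
        request.dataset = dataset
        return request

The mixed-arguments ValueError case continues below.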
with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) -def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): +def test_get_dataset( + transport: str = "grpc", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -578,19 +658,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.get_dataset(request) @@ -604,13 +678,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_dataset_from_dict(): @@ -618,10 +692,9 @@ def test_get_dataset_from_dict(): @pytest.mark.asyncio -async def test_get_dataset_async(transport: str = 'grpc_asyncio'): +async def test_get_dataset_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -629,16 +702,16 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio'): request = dataset_service.GetDatasetRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.get_dataset(request) @@ -651,29 +724,25 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -685,27 +754,20 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -717,99 +779,79 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset( - name='name_value', - ) + client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset( - name='name_value', - ) + response = await client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) -def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): +def test_update_dataset( + transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -817,19 +859,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
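Every RPC test here uses the same two-step fake: patch __call__ on the transport's bound method, then return either a plain proto (sync client) or an awaitable FakeUnaryUnaryCall (async client). Condensed, using the real google.api_core helper (client and gca_dataset are as set up in these tests):

    import mock
    from google.api_core import grpc_helpers_async

    with mock.patch.object(
        type(client.transport.update_dataset), "__call__"
    ) as call:
        call.return_value = gca_dataset.Dataset()  # sync stub result
        # For the async client, wrap it so it can be awaited:
        # call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
        #     gca_dataset.Dataset()
        # )

The update_dataset stub continues below.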
- with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.update_dataset(request) @@ -843,13 +879,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_dataset_from_dict(): @@ -857,10 +893,9 @@ def test_update_dataset_from_dict(): @pytest.mark.asyncio -async def test_update_dataset_async(transport: str = 'grpc_asyncio'): +async def test_update_dataset_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -868,16 +903,16 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio'): request = dataset_service.UpdateDatasetRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.update_dataset(request) @@ -890,29 +925,25 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -924,27 +955,22 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -956,29 +982,24 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] def test_update_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
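update_dataset pairs the replacement Dataset with a FieldMask naming exactly the fields to overwrite; fields outside the mask keep their server-side values. For example (the path "display_name" is chosen for illustration):

    from google.protobuf import field_mask_pb2

    mask = field_mask_pb2.FieldMask(paths=["display_name"])
    # Sent alongside the Dataset, this mask restricts the write to display_name.

The flattened call continues below.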
client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -986,36 +1007,30 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1023,8 +1038,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1032,31 +1047,30 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): +def test_list_datasets( + transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1064,13 +1078,10 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_datasets(request) @@ -1084,7 +1095,7 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_datasets_from_dict(): @@ -1092,10 +1103,9 @@ def test_list_datasets_from_dict(): @pytest.mark.asyncio -async def test_list_datasets_async(transport: str = 'grpc_asyncio'): +async def test_list_datasets_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1103,13 +1113,13 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio'): request = dataset_service.ListDatasetsRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_datasets(request) @@ -1122,23 +1132,19 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_datasets_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1150,28 +1156,23 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) await client.list_datasets(request) @@ -1182,138 +1183,100 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_datasets_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets( - parent='parent_value', - ) + client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
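    # Exactly one stub call should have been made, with the flattened
    # `parent` kwarg folded into the request object seen as args[0].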
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_datasets_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets( - parent='parent_value', - ) + response = await client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) def test_list_datasets_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. 
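        # Each response below is one page (3, 0, 1, and 2 datasets). The
        # pager keeps calling the stub while next_page_token is non-empty;
        # the trailing RuntimeError would surface if it over-fetched past
        # the final page.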
call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_datasets(request={}) @@ -1321,147 +1284,102 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in results) + assert all(isinstance(i, dataset.Dataset) for i in results) + def test_list_datasets_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
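        # Same four pages as the sync pager test; the async pager should
        # yield all six datasets across them.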
call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in responses) + assert all(isinstance(i, dataset.Dataset) for i in responses) + @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): +def test_delete_dataset( + transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1469,11 +1387,9 @@ def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.De request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
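    # delete_dataset is a long-running method: the transport returns an
    # operations_pb2.Operation rather than the deleted resource.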
- with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_dataset(request) @@ -1492,10 +1408,9 @@ def test_delete_dataset_from_dict(): @pytest.mark.asyncio -async def test_delete_dataset_async(transport: str = 'grpc_asyncio'): +async def test_delete_dataset_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1503,12 +1418,10 @@ async def test_delete_dataset_async(transport: str = 'grpc_asyncio'): request = dataset_service.DeleteDatasetRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_dataset(request) @@ -1524,20 +1437,16 @@ async def test_delete_dataset_async(transport: str = 'grpc_asyncio'): def test_delete_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_dataset(request) @@ -1548,28 +1457,23 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
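    # FakeUnaryUnaryCall wraps the fake Operation so that awaiting the
    # mocked stub behaves like a real async unary-unary call.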
- with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_dataset(request) @@ -1580,101 +1484,81 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset( - name='name_value', - ) + client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset( - name='name_value', - ) + response = await client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
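    # The async variant only asserts that a call happened, but the
    # flattened `name` must still appear on the captured request.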
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) -def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): +def test_import_data( + transport: str = "grpc", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1682,11 +1566,9 @@ def test_import_data(transport: str = 'grpc', request_type=dataset_service.Impor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.import_data(request) @@ -1705,10 +1587,9 @@ def test_import_data_from_dict(): @pytest.mark.asyncio -async def test_import_data_async(transport: str = 'grpc_asyncio'): +async def test_import_data_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1716,12 +1597,10 @@ async def test_import_data_async(transport: str = 'grpc_asyncio'): request = dataset_service.ImportDataRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.import_data(request) @@ -1737,20 +1616,16 @@ async def test_import_data_async(transport: str = 'grpc_asyncio'): def test_import_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.import_data(request) @@ -1761,29 +1636,24 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) - + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.import_data(request) # Establish that the underlying gRPC stub method was called. @@ -1793,29 +1663,24 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_import_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
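        # import_configs is a repeated field, so the flattened value is a
        # list of ImportDataConfig messages.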
client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -1823,47 +1688,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] def test_import_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -1871,31 +1736,34 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) -def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): +def test_export_data( + transport: str = "grpc", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1903,11 +1771,9 @@ def test_export_data(transport: str = 'grpc', request_type=dataset_service.Expor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_data(request) @@ -1926,10 +1792,9 @@ def test_export_data_from_dict(): @pytest.mark.asyncio -async def test_export_data_async(transport: str = 'grpc_asyncio'): +async def test_export_data_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1937,12 +1802,10 @@ async def test_export_data_async(transport: str = 'grpc_asyncio'): request = dataset_service.ExportDataRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_data(request) @@ -1958,20 +1821,16 @@ async def test_export_data_async(transport: str = 'grpc_asyncio'): def test_export_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_data(request) @@ -1982,28 +1841,23 @@ def test_export_data_field_headers(): # Establish that the field header was sent. 
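    # The routing header travels in the stub call's `metadata` kwarg as an
    # x-goog-request-params entry derived from request.name.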
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_data(request) @@ -2014,29 +1868,26 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2044,47 +1895,53 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) def test_export_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2092,31 +1949,38 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) -def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): +def test_list_data_items( + transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2124,13 +1988,10 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_items(request) @@ -2144,7 +2005,7 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_items_from_dict(): @@ -2152,10 +2013,9 @@ def test_list_data_items_from_dict(): @pytest.mark.asyncio -async def test_list_data_items_async(transport: str = 'grpc_asyncio'): +async def test_list_data_items_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2163,13 +2023,13 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio'): request = dataset_service.ListDataItemsRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_items(request) @@ -2182,23 +2042,19 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. 
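    # Data items page the same way datasets do, via an async pager.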
assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_items_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2210,28 +2066,23 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) await client.list_data_items(request) @@ -2242,104 +2093,81 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_items_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items( - parent='parent_value', - ) + client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_items_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items( - parent='parent_value', - ) + response = await client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) def test_list_data_items_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2348,32 +2176,23 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_items(request={}) @@ -2381,18 +2200,14 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in results) + assert all(isinstance(i, data_item.DataItem) for i in results) + def test_list_data_items_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2401,40 +2216,32 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2443,46 +2250,37 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in responses) + assert all(isinstance(i, data_item.DataItem) for i in responses) + @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2491,37 +2289,31 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): +def test_get_annotation_spec( + transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2530,16 +2322,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. 
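        # get_annotation_spec is a plain unary method; the scalar fields
        # set here are asserted back on the response below.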
call.return_value = annotation_spec.AnnotationSpec( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - + name="name_value", display_name="display_name_value", etag="etag_value", ) response = client.get_annotation_spec(request) @@ -2553,11 +2340,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_annotation_spec_from_dict(): @@ -2565,10 +2352,9 @@ def test_get_annotation_spec_from_dict(): @pytest.mark.asyncio -async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio'): +async def test_get_annotation_spec_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2577,14 +2363,14 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( - name='name_value', - display_name='display_name_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + ) response = await client.get_annotation_spec(request) @@ -2597,27 +2383,25 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2629,28 +2413,25 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) await client.get_annotation_spec(request) @@ -2661,99 +2442,85 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec( - name='name_value', - ) + client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. 
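        # An async transport stub must hand back an awaitable, which is why the
        # plain message gets wrapped in grpc_helpers_async.FakeUnaryUnaryCall
        # below; awaiting the client call then resolves to the message itself.
        # A minimal sketch of that wrapper's role (helper name illustrative):
        def awaitable_response(message):
            """Wrap a response so an async client stub can be awaited."""
            return grpc_helpers_async.FakeUnaryUnaryCall(message)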
call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec( - name='name_value', - ) + response = await client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) -def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): +def test_list_annotations( + transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2761,13 +2528,10 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_annotations(request) @@ -2781,7 +2545,7 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_annotations_from_dict(): @@ -2789,10 +2553,9 @@ def test_list_annotations_from_dict(): @pytest.mark.asyncio -async def test_list_annotations_async(transport: str = 'grpc_asyncio'): +async def test_list_annotations_async(transport: str = "grpc_asyncio"): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2800,13 +2563,13 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio'): request = dataset_service.ListAnnotationsRequest() # Mock the actual call within the gRPC stub, and fake the request. 
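    # Patching type(client.transport.list_annotations).__call__ swaps only the
    # bound gRPC callable, so request coercion, routing metadata, and pager
    # wrapping still run in-process with no network involved. A minimal sketch
    # of the idiom outside the generated suite (helper name illustrative):
    def stub_rpc(client, method_name):
        """Patch one transport RPC so the test controls its return value."""
        method = getattr(client.transport, method_name)
        return mock.patch.object(type(method), "__call__")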
- with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_annotations(request) @@ -2819,23 +2582,19 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_annotations_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -2847,28 +2606,23 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) await client.list_annotations(request) @@ -2879,104 +2633,81 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_annotations_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations( - parent='parent_value', - ) + client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_annotations_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations( - parent='parent_value', - ) + response = await client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
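    # The generated clients never merge the two calling conventions: passing a
    # request object together with any flattened field raises ValueError before
    # an RPC is attempted. A simplified sketch of that guard (illustrative, not
    # the generated code):
    def ensure_exclusive(request, **flattened):
        """Reject mixing a request message with flattened keyword fields."""
        if request is not None and any(v is not None for v in flattened.values()):
            raise ValueError("Cannot pass both a request and flattened fields.")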
with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) def test_list_annotations_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2985,32 +2716,23 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_annotations(request={}) @@ -3018,18 +2740,14 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in results) + assert all(isinstance(i, annotation.Annotation) for i in results) + def test_list_annotations_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. 
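        # side_effect queues one response per invocation: the pager re-issues
        # the RPC while next_page_token is non-empty, and the trailing
        # RuntimeError would surface any unexpected extra call. A simplified
        # sketch of the token-following loop the pager implements (hypothetical
        # helper; the real pager also threads metadata through each call):
        def iterate_pages(list_rpc, request):
            """Yield responses page by page until the token runs out."""
            while True:
                page = list_rpc(request)
                yield page
                if not page.next_page_token:
                    return
                request.page_token = page.next_page_token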
call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3038,40 +2756,32 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3080,46 +2790,37 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in responses) + assert all(isinstance(i, annotation.Annotation) for i in responses) + @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3128,30 +2829,23 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3162,8 +2856,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -3182,8 +2875,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3211,13 +2903,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3225,13 +2920,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.DatasetServiceGrpcTransport, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) def test_dataset_service_base_transport_error(): @@ -3239,13 +2929,15 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_dataset_service_base_transport(): # Instantiate the base transport. 
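    # The base transport is abstract: construction only wires credentials and
    # scopes, while every RPC attribute raises NotImplementedError until a
    # concrete gRPC or gRPC-asyncio transport overrides it, as the method loop
    # further down verifies. A simplified sketch of that contract (class name
    # illustrative):
    class _AbstractRpc:
        def list_datasets(self, request):
            raise NotImplementedError()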
- with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3254,17 +2946,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_dataset', - 'get_dataset', - 'update_dataset', - 'list_datasets', - 'delete_dataset', - 'import_data', - 'export_data', - 'list_data_items', - 'get_annotation_spec', - 'list_annotations', - ) + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3277,23 +2969,28 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -3302,11 +2999,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. 
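    # Application Default Credentials come from auth.default(), which consults
    # GOOGLE_APPLICATION_CREDENTIALS, gcloud user credentials, and the metadata
    # server, in that order. A minimal sketch of the lookup the client performs
    # when no credentials are passed (would hit real ADC sources if executed):
    def resolve_adc():
        """Resolve Application Default Credentials with the cloud scope."""
        return auth.default(
            scopes=("https://www.googleapis.com/auth/cloud-platform",)
        )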
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3314,60 +3011,75 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.DatasetServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3376,7 +3088,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3392,26 +3104,30 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) -def test_dataset_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3428,9 +3144,7 @@ def test_dataset_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3439,16 +3153,12 @@ def test_dataset_service_transport_channel_mtls_with_adc( def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. 
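    # The operations client polls google.longrunning Operations over the same
    # channel; callers usually consume it indirectly through the api_core
    # Operation future that an LRO method returns. A hedged usage sketch,
    # assuming `lro` is such a future (e.g. from create_dataset):
    def wait_for(lro, timeout=300):
        """Block until the long-running operation finishes or raises."""
        return lro.result(timeout=timeout)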
- assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3456,20 +3166,17 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_annotation_path(): project = "squid" location = "clam" @@ -3477,19 +3184,26 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) - actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) + actual = DatasetServiceClient.annotation_path( + project, location, dataset, data_item, annotation + ) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", } path = DatasetServiceClient.annotation_path(**expected) @@ -3497,24 +3211,31 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual + def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) - actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + actual = DatasetServiceClient.annotation_spec_path( + project, location, dataset, annotation_spec + ) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", - + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + "annotation_spec": "nudibranch", } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3522,24 +3243,26 @@ def test_parse_annotation_spec_path(): 
actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual + def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", - + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", } path = DatasetServiceClient.data_item_path(**expected) @@ -3547,22 +3270,24 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual + def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", } path = DatasetServiceClient.dataset_path(**expected) @@ -3570,18 +3295,20 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3589,18 +3316,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = DatasetServiceClient.common_folder_path(**expected) @@ -3608,18 +3335,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def 
test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = DatasetServiceClient.common_organization_path(**expected) @@ -3627,18 +3354,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = DatasetServiceClient.common_project_path(**expected) @@ -3646,20 +3373,22 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = DatasetServiceClient.common_location_path(**expected) @@ -3671,17 +3400,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 3d638675f1..45895347ec 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceClient +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceClient, +) from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from 
google.cloud.aiplatform_v1beta1.services.endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type @@ -62,7 +66,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -73,17 +81,35 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [EndpointServiceClient, EndpointServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] +) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -91,7 +117,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_client_get_transport_class(): @@ -102,29 +128,44 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) -def test_endpoint_service_client_client_options(client_class, 
transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -140,7 +181,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -156,7 +197,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -176,13 +217,15 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -195,26 +238,66 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
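    # Two environment variables drive the endpoint logic under test:
    # GOOGLE_API_USE_CLIENT_CERTIFICATE decides whether a client certificate is
    # presented, and GOOGLE_API_USE_MTLS_ENDPOINT ("auto" here) flips the host
    # to the mTLS variant whenever a certificate is actually in play. A
    # simplified sketch of that decision (illustrative, not the generated logic):
    def choose_endpoint(default, mtls_default, have_cert):
        """Pick the endpoint under GOOGLE_API_USE_MTLS_ENDPOINT=auto."""
        use_cert = (
            os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
        )
        return mtls_default if (use_cert and have_cert) else default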
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -237,11 +320,21 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -251,7 +344,9 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -266,10 +361,17 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -284,16 +386,23 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -306,16 +415,24 @@ def test_endpoint_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -330,10 +447,12 @@ def test_endpoint_service_client_client_options_credentials_file(client_class, t def test_endpoint_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -346,10 +465,11 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): +def test_create_endpoint( + transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -357,11 +477,9 @@ def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_endpoint(request) @@ -380,10 +498,9 @@ def test_create_endpoint_from_dict(): @pytest.mark.asyncio -async def test_create_endpoint_async(transport: str = 'grpc_asyncio'): +async def test_create_endpoint_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -391,12 +508,10 @@ async def test_create_endpoint_async(transport: str = 'grpc_asyncio'): request = endpoint_service.CreateEndpointRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_endpoint(request) @@ -412,20 +527,16 @@ async def test_create_endpoint_async(transport: str = 'grpc_asyncio'): def test_create_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_endpoint(request) @@ -436,28 +547,23 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_endpoint(request) @@ -468,29 +574,21 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
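        # (A sketch of what the client presumably does with flattened kwargs:
        # copy each one onto a fresh endpoint_service.CreateEndpointRequest,
        # which is why the assertions below inspect args[0] rather than the
        # kwargs themselves.)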
client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -498,47 +596,40 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") def test_create_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -546,31 +637,30 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) -def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): +def test_get_endpoint( + transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -578,19 +668,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.get_endpoint(request) @@ -604,13 +688,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_endpoint_from_dict(): @@ -618,10 +702,9 @@ def test_get_endpoint_from_dict(): @pytest.mark.asyncio -async def test_get_endpoint_async(transport: str = 'grpc_asyncio'): +async def test_get_endpoint_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -629,16 +712,16 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio'): request = endpoint_service.GetEndpointRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.get_endpoint(request) @@ -651,29 +734,25 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -685,27 +764,20 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -717,99 +789,79 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
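    # (Patching __call__ on type(client.transport.get_endpoint) replaces the
    # underlying stub method itself, so every invocation path through the
    # client lands on the same mock.)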
- with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint( - name='name_value', - ) + client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_endpoint( - name='name_value', - ) + response = await client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) -def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): +def test_list_endpoints( + transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -817,13 +869,10 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_endpoints(request) @@ -837,7 +886,7 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_endpoints_from_dict(): @@ -845,10 +894,9 @@ def test_list_endpoints_from_dict(): @pytest.mark.asyncio -async def test_list_endpoints_async(transport: str = 'grpc_asyncio'): +async def test_list_endpoints_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -856,13 +904,13 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio'): request = endpoint_service.ListEndpointsRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_endpoints(request) @@ -875,23 +923,19 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_endpoints_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -903,28 +947,23 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. 
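    # (kw["metadata"] should include the routing header derived from
    # request.parent, roughly:
    #     gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),))
    # which yields ("x-goog-request-params", "parent=parent/value").)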
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) await client.list_endpoints(request) @@ -935,104 +974,81 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_endpoints_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints( - parent='parent_value', - ) + client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_endpoints_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints( - parent='parent_value', - ) + response = await client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) def test_list_endpoints_pager(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1041,32 +1057,23 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_endpoints(request={}) @@ -1074,18 +1081,14 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in results) + assert all(isinstance(i, endpoint.Endpoint) for i in results) + def test_list_endpoints_pages(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. 
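        # (side_effect hands back each response below in order; paging stops at
        # the page whose next_page_token is empty, so the trailing RuntimeError
        # only fires if the pager over-fetches past the last page.)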
call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1094,40 +1097,32 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1136,46 +1131,37 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in responses) + assert all(isinstance(i, endpoint.Endpoint) for i in responses) + @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1184,37 +1170,31 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): +def test_update_endpoint( + transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1222,19 +1202,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.update_endpoint(request) @@ -1248,13 +1222,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_endpoint_from_dict(): @@ -1262,10 +1236,9 @@ def test_update_endpoint_from_dict(): @pytest.mark.asyncio -async def test_update_endpoint_async(transport: str = 'grpc_asyncio'): +async def test_update_endpoint_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1273,16 +1246,16 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio'): request = endpoint_service.UpdateEndpointRequest() # Mock the actual call within the gRPC stub, and fake the request. 
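    # (grpc_helpers_async.FakeUnaryUnaryCall wraps the canned response so the
    # mocked stub can be awaited like a real async unary-unary call.)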
- with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.update_endpoint(request) @@ -1295,29 +1268,25 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1329,28 +1298,25 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) await client.update_endpoint(request) @@ -1361,29 +1327,24 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] def test_update_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1391,45 +1352,41 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1437,31 +1394,30 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): +def test_delete_endpoint( + transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1469,11 +1425,9 @@ def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_endpoint(request) @@ -1492,10 +1446,9 @@ def test_delete_endpoint_from_dict(): @pytest.mark.asyncio -async def test_delete_endpoint_async(transport: str = 'grpc_asyncio'): +async def test_delete_endpoint_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1503,12 +1456,10 @@ async def test_delete_endpoint_async(transport: str = 'grpc_asyncio'): request = endpoint_service.DeleteEndpointRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_endpoint(request) @@ -1524,20 +1475,16 @@ async def test_delete_endpoint_async(transport: str = 'grpc_asyncio'): def test_delete_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_endpoint(request) @@ -1548,28 +1495,23 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_endpoint(request) @@ -1580,101 +1522,81 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.delete_endpoint( - name='name_value', - ) + client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint( - name='name_value', - ) + response = await client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) -def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): +def test_deploy_model( + transport: str = "grpc", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1682,11 +1604,9 @@ def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.Dep request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. 
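        # (deploy_model is a long-running operation, so the stub returns a raw
        # operations_pb2.Operation here; presumably the client wraps it in an
        # api_core operation future before handing it back.)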
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.deploy_model(request) @@ -1705,10 +1625,9 @@ def test_deploy_model_from_dict(): @pytest.mark.asyncio -async def test_deploy_model_async(transport: str = 'grpc_asyncio'): +async def test_deploy_model_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1716,12 +1635,10 @@ async def test_deploy_model_async(transport: str = 'grpc_asyncio'): request = endpoint_service.DeployModelRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.deploy_model(request) @@ -1737,20 +1654,16 @@ async def test_deploy_model_async(transport: str = 'grpc_asyncio'): def test_deploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.deploy_model(request) @@ -1761,28 +1674,23 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.deploy_model(request) @@ -1793,30 +1701,29 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_deploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -1824,51 +1731,63 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
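    # (The generated surface presumably treats a request object and flattened
    # fields as mutually exclusive, raising ValueError client-side before any
    # RPC is attempted.)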
with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -1876,34 +1795,45 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) -def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): +def test_undeploy_model( + transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1911,11 +1841,9 @@ def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.U request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.undeploy_model(request) @@ -1934,10 +1862,9 @@ def test_undeploy_model_from_dict(): @pytest.mark.asyncio -async def test_undeploy_model_async(transport: str = 'grpc_asyncio'): +async def test_undeploy_model_async(transport: str = "grpc_asyncio"): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1945,12 +1872,10 @@ async def test_undeploy_model_async(transport: str = 'grpc_asyncio'): request = endpoint_service.UndeployModelRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.undeploy_model(request) @@ -1966,20 +1891,16 @@ async def test_undeploy_model_async(transport: str = 'grpc_asyncio'): def test_undeploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.undeploy_model(request) @@ -1990,28 +1911,23 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.undeploy_model(request) @@ -2022,30 +1938,23 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_undeploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2053,51 +1962,45 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2105,27 +2008,25 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @@ -2136,8 +2037,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2156,8 +2056,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -2185,13 +2084,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2199,13 +2101,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.EndpointServiceGrpcTransport, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) def test_endpoint_service_base_transport_error(): @@ -2213,13 +2110,15 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2228,14 +2127,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
methods = ( - 'create_endpoint', - 'get_endpoint', - 'list_endpoints', - 'update_endpoint', - 'delete_endpoint', - 'deploy_model', - 'undeploy_model', - ) + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2248,23 +2147,28 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2273,11 +2177,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -2285,60 +2189,75 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. 
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.EndpointServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2347,7 +2266,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2363,26 +2282,30 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) -def test_endpoint_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2399,9 +2322,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -2410,16 +2331,12 @@ def test_endpoint_service_transport_channel_mtls_with_adc( def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core 
operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2427,36 +2344,34 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = EndpointServiceClient.endpoint_path(**expected) @@ -2464,22 +2379,24 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = EndpointServiceClient.model_path(**expected) @@ -2487,18 +2404,20 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2506,18 +2425,18 @@ def test_parse_common_billing_account_path(): actual = EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + 
expected = "folders/{folder}".format(folder=folder,) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = EndpointServiceClient.common_folder_path(**expected) @@ -2525,18 +2444,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = EndpointServiceClient.common_organization_path(**expected) @@ -2544,18 +2463,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = EndpointServiceClient.common_project_path(**expected) @@ -2563,20 +2482,22 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = EndpointServiceClient.common_location_path(**expected) @@ -2588,17 +2509,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index fa57d58228..19a9fe139c 100644 --- 
a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -41,14 +41,20 @@ from google.cloud.aiplatform_v1beta1.services.job_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state @@ -75,7 +81,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -86,17 +96,30 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi @pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials 
== creds @@ -104,7 +127,7 @@ def test_job_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_client_get_transport_class(): @@ -115,29 +138,42 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) -def test_job_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -153,7 +189,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -169,7 +205,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -189,13 +225,15 @@ def test_job_service_client_client_options(client_class, transport_class, transp client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -208,26 +246,54 @@ def test_job_service_client_client_options(client_class, transport_class, transp client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_job_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -250,11 +316,21 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -264,7 +340,9 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -279,10 +357,17 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -297,16 +382,23 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -319,16 +411,24 @@ def test_job_service_client_client_options_scopes(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -343,11 +443,11 @@ def test_job_service_client_client_options_credentials_file(client_class, transp def test_job_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -359,10 +459,11 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): +def test_create_custom_job( + transport: str = "grpc", request_type=job_service.CreateCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -371,16 +472,13 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_custom_job(request) @@ -394,9 +492,9 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -406,10 +504,9 @@ def test_create_custom_job_from_dict(): @pytest.mark.asyncio -async def test_create_custom_job_async(transport: str = 'grpc_asyncio'): +async def test_create_custom_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -418,14 +515,16 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_custom_job(request) @@ -438,27 +537,25 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED def test_create_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -470,28 +567,25 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) await client.create_custom_job(request) @@ -502,29 +596,24 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -532,45 +621,43 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") def test_create_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
         response = await client.create_custom_job(
-            parent='parent_value',
-            custom_job=gca_custom_job.CustomJob(name='name_value'),
+            parent="parent_value",
+            custom_job=gca_custom_job.CustomJob(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -578,31 +665,30 @@ async def test_create_custom_job_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value')
+        assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value")


 @pytest.mark.asyncio
 async def test_create_custom_job_flattened_error_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.create_custom_job(
             job_service.CreateCustomJobRequest(),
-            parent='parent_value',
-            custom_job=gca_custom_job.CustomJob(name='name_value'),
+            parent="parent_value",
+            custom_job=gca_custom_job.CustomJob(name="name_value"),
         )


-def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest):
+def test_get_custom_job(
+    transport: str = "grpc", request_type=job_service.GetCustomJobRequest
+):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -610,17 +696,12 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_custom_job),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = custom_job.CustomJob(
-            name='name_value',
-
-            display_name='display_name_value',
-
+            name="name_value",
+            display_name="display_name_value",
             state=job_state.JobState.JOB_STATE_QUEUED,
-
         )

         response = client.get_custom_job(request)

@@ -634,9 +715,9 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus
     # Establish that the response is the type that we expect.
     assert isinstance(response, custom_job.CustomJob)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

     assert response.state == job_state.JobState.JOB_STATE_QUEUED


 def test_get_custom_job_from_dict():
@@ -646,10 +727,9 @@

 @pytest.mark.asyncio
-async def test_get_custom_job_async(transport: str = 'grpc_asyncio'):
+async def test_get_custom_job_async(transport: str = "grpc_asyncio"):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -657,15 +737,15 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio'):
     request = job_service.GetCustomJobRequest()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_custom_job),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob(
-            name='name_value',
-            display_name='display_name_value',
-            state=job_state.JobState.JOB_STATE_QUEUED,
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            custom_job.CustomJob(
+                name="name_value",
+                display_name="display_name_value",
+                state=job_state.JobState.JOB_STATE_QUEUED,
+            )
+        )

         response = await client.get_custom_job(request)

@@ -678,27 +758,23 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio'):
     # Establish that the response is the type that we expect.
     assert isinstance(response, custom_job.CustomJob)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

     assert response.state == job_state.JobState.JOB_STATE_QUEUED


 def test_get_custom_job_field_headers():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.GetCustomJobRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_custom_job),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
         call.return_value = custom_job.CustomJob()

         client.get_custom_job(request)

@@ -710,28 +786,23 @@ def test_get_custom_job_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_get_custom_job_field_headers_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.GetCustomJobRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_custom_job),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob())
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            custom_job.CustomJob()
+        )

         await client.get_custom_job(request)

@@ -742,99 +813,81 @@ async def test_get_custom_job_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_get_custom_job_flattened():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_custom_job),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = custom_job.CustomJob()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.get_custom_job(
-            name='name_value',
-        )
+        client.get_custom_job(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_get_custom_job_flattened_error():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.get_custom_job(
-            job_service.GetCustomJobRequest(),
-            name='name_value',
+            job_service.GetCustomJobRequest(), name="name_value",
         )


 @pytest.mark.asyncio
 async def test_get_custom_job_flattened_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_custom_job),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = custom_job.CustomJob()
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            custom_job.CustomJob()
+        )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.get_custom_job(
-            name='name_value',
-        )
+        response = await client.get_custom_job(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_get_custom_job_flattened_error_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.get_custom_job(
-            job_service.GetCustomJobRequest(),
-            name='name_value',
+            job_service.GetCustomJobRequest(), name="name_value",
         )


-def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest):
+def test_list_custom_jobs(
+    transport: str = "grpc", request_type=job_service.ListCustomJobsRequest
+):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -842,13 +895,10 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = job_service.ListCustomJobsResponse(
-            next_page_token='next_page_token_value',
-
+            next_page_token="next_page_token_value",
         )

         response = client.list_custom_jobs(request)

@@ -862,7 +912,7 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List
     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListCustomJobsPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_list_custom_jobs_from_dict():
@@ -870,10 +920,9 @@

 @pytest.mark.asyncio
-async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio'):
+async def test_list_custom_jobs_async(transport: str = "grpc_asyncio"):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -881,13 +930,11 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio'):
     request = job_service.ListCustomJobsRequest()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse(
-            next_page_token='next_page_token_value',
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",)
+        )

         response = await client.list_custom_jobs(request)

@@ -900,23 +947,19 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio'):
     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListCustomJobsAsyncPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_list_custom_jobs_field_headers():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.ListCustomJobsRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         call.return_value = job_service.ListCustomJobsResponse()

         client.list_custom_jobs(request)

@@ -928,28 +971,23 @@ def test_list_custom_jobs_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_list_custom_jobs_field_headers_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.ListCustomJobsRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse())
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListCustomJobsResponse()
+        )

         await client.list_custom_jobs(request)

@@ -960,104 +998,81 @@ async def test_list_custom_jobs_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_list_custom_jobs_flattened():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = job_service.ListCustomJobsResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_custom_jobs(
-            parent='parent_value',
-        )
+        client.list_custom_jobs(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 def test_list_custom_jobs_flattened_error():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_custom_jobs(
-            job_service.ListCustomJobsRequest(),
-            parent='parent_value',
+            job_service.ListCustomJobsRequest(), parent="parent_value",
         )


 @pytest.mark.asyncio
 async def test_list_custom_jobs_flattened_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = job_service.ListCustomJobsResponse()
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListCustomJobsResponse()
+        )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.list_custom_jobs(
-            parent='parent_value',
-        )
+        response = await client.list_custom_jobs(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 @pytest.mark.asyncio
 async def test_list_custom_jobs_flattened_error_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.list_custom_jobs(
-            job_service.ListCustomJobsRequest(),
-            parent='parent_value',
+            job_service.ListCustomJobsRequest(), parent="parent_value",
         )


 def test_list_custom_jobs_pager():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListCustomJobsResponse(
@@ -1066,32 +1081,21 @@ def test_list_custom_jobs_pager():
                     custom_job.CustomJob(),
                     custom_job.CustomJob(),
                 ],
-                next_page_token='abc',
-            ),
-            job_service.ListCustomJobsResponse(
-                custom_jobs=[],
-                next_page_token='def',
+                next_page_token="abc",
             ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                ],
-                next_page_token='ghi',
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
             ),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                    custom_job.CustomJob(),
-                ],
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
             ),
             RuntimeError,
         )

         metadata = ()
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', ''),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
         )
         pager = client.list_custom_jobs(request={})

@@ -1099,18 +1103,14 @@
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(isinstance(i, custom_job.CustomJob)
-                   for i in results)
+        assert all(isinstance(i, custom_job.CustomJob) for i in results)
+

 def test_list_custom_jobs_pages():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListCustomJobsResponse(
@@ -1119,40 +1119,30 @@
                     custom_job.CustomJob(),
                     custom_job.CustomJob(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[],
-                next_page_token='def',
-            ),
-            job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                ],
-                next_page_token='ghi',
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
             ),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                    custom_job.CustomJob(),
-                ],
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
             ),
             RuntimeError,
         )
         pages = list(client.list_custom_jobs(request={}).pages)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
             assert page_.raw_page.next_page_token == token

+
 @pytest.mark.asyncio
 async def test_list_custom_jobs_async_pager():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListCustomJobsResponse(
@@ -1161,46 +1151,35 @@ async def test_list_custom_jobs_async_pager():
                     custom_job.CustomJob(),
                     custom_job.CustomJob(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[],
-                next_page_token='def',
-            ),
-            job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                ],
-                next_page_token='ghi',
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
             ),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                    custom_job.CustomJob(),
-                ],
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
             ),
             RuntimeError,
         )
         async_pager = await client.list_custom_jobs(request={},)
-        assert async_pager.next_page_token == 'abc'
+        assert async_pager.next_page_token == "abc"
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(isinstance(i, custom_job.CustomJob)
-                   for i in responses)
+        assert all(isinstance(i, custom_job.CustomJob) for i in responses)

+
 @pytest.mark.asyncio
 async def test_list_custom_jobs_async_pages():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_custom_jobs),
-        '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListCustomJobsResponse(
@@ -1209,37 +1188,29 @@ async def test_list_custom_jobs_async_pages():
                     custom_job.CustomJob(),
                     custom_job.CustomJob(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[],
-                next_page_token='def',
             ),
             job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                ],
-                next_page_token='ghi',
-            ),
-            job_service.ListCustomJobsResponse(
-                custom_jobs=[
-                    custom_job.CustomJob(),
-                    custom_job.CustomJob(),
-                ],
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
             ),
             job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
             ),
             RuntimeError,
         )
         pages = []
         async for page_ in (await client.list_custom_jobs(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
             assert page_.raw_page.next_page_token == token


-def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest):
+def test_delete_custom_job(
+    transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest
+):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1248,10 +1219,10 @@
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_custom_job),
-        '__call__') as call:
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/spam')
+        call.return_value = operations_pb2.Operation(name="operations/spam")

         response = client.delete_custom_job(request)

@@ -1270,10 +1241,9 @@ def test_delete_custom_job_from_dict():

 @pytest.mark.asyncio
-async def test_delete_custom_job_async(transport: str = 'grpc_asyncio'):
+async def test_delete_custom_job_async(transport: str = "grpc_asyncio"):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1282,11 +1252,11 @@
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_custom_job),
-        '__call__') as call:
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )

         response = await client.delete_custom_job(request)

@@ -1302,20 +1272,18 @@

 def test_delete_custom_job_field_headers():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.DeleteCustomJobRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_custom_job),
-        '__call__') as call:
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")

         client.delete_custom_job(request)

@@ -1326,28 +1294,25 @@ def test_delete_custom_job_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_delete_custom_job_field_headers_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.DeleteCustomJobRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_custom_job),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )

         await client.delete_custom_job(request)

@@ -1358,101 +1323,85 @@ async def test_delete_custom_job_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_delete_custom_job_flattened():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_custom_job),
-        '__call__') as call:
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.delete_custom_job(
-            name='name_value',
-        )
+        client.delete_custom_job(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_delete_custom_job_flattened_error():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.delete_custom_job(
-            job_service.DeleteCustomJobRequest(),
-            name='name_value',
+            job_service.DeleteCustomJobRequest(), name="name_value",
         )


 @pytest.mark.asyncio
 async def test_delete_custom_job_flattened_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_custom_job),
-        '__call__') as call:
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.delete_custom_job(
-            name='name_value',
-        )
+        response = await client.delete_custom_job(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_delete_custom_job_flattened_error_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.delete_custom_job(
-            job_service.DeleteCustomJobRequest(),
-            name='name_value',
+            job_service.DeleteCustomJobRequest(), name="name_value",
         )


-def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest):
+def test_cancel_custom_job(
+    transport: str = "grpc", request_type=job_service.CancelCustomJobRequest
+):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1461,8 +1410,8 @@
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_custom_job),
-        '__call__') as call:
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

@@ -1483,10 +1432,9 @@ def test_cancel_custom_job_from_dict():

 @pytest.mark.asyncio
-async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio'):
+async def test_cancel_custom_job_async(transport: str = "grpc_asyncio"):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1495,8 +1443,8 @@
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_custom_job),
-        '__call__') as call:
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

@@ -1513,19 +1461,17 @@

 def test_cancel_custom_job_field_headers():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CancelCustomJobRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_custom_job),
-        '__call__') as call:
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
         call.return_value = None

         client.cancel_custom_job(request)

@@ -1537,27 +1483,22 @@
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_cancel_custom_job_field_headers_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CancelCustomJobRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_custom_job),
-        '__call__') as call:
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         await client.cancel_custom_job(request)

@@ -1569,99 +1510,83 @@
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_cancel_custom_job_flattened():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_custom_job),
-        '__call__') as call:
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.cancel_custom_job(
-            name='name_value',
-        )
+        client.cancel_custom_job(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_cancel_custom_job_flattened_error():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.cancel_custom_job(
-            job_service.CancelCustomJobRequest(),
-            name='name_value',
+            job_service.CancelCustomJobRequest(), name="name_value",
        )


 @pytest.mark.asyncio
 async def test_cancel_custom_job_flattened_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_custom_job),
-        '__call__') as call:
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = None
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.cancel_custom_job(
-            name='name_value',
-        )
+        response = await client.cancel_custom_job(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_cancel_custom_job_flattened_error_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.cancel_custom_job(
-            job_service.CancelCustomJobRequest(),
-            name='name_value',
+            job_service.CancelCustomJobRequest(), name="name_value",
         )


-def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest):
+def test_create_data_labeling_job(
+    transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest
+):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1670,28 +1595,19 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_data_labeling_job.DataLabelingJob(
-            name='name_value',
-
-            display_name='display_name_value',
-
-            datasets=['datasets_value'],
-
+            name="name_value",
+            display_name="display_name_value",
+            datasets=["datasets_value"],
             labeler_count=1375,
-
-            instruction_uri='instruction_uri_value',
-
-            inputs_schema_uri='inputs_schema_uri_value',
-
+            instruction_uri="instruction_uri_value",
+            inputs_schema_uri="inputs_schema_uri_value",
             state=job_state.JobState.JOB_STATE_QUEUED,
-
             labeling_progress=1810,
-
-            specialist_pools=['specialist_pools_value'],
-
+            specialist_pools=["specialist_pools_value"],
         )

         response = client.create_data_labeling_job(request)

@@ -1705,23 +1621,23 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv
     # Establish that the response is the type that we expect.
     assert isinstance(response, gca_data_labeling_job.DataLabelingJob)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.datasets == ['datasets_value']
+    assert response.datasets == ["datasets_value"]

     assert response.labeler_count == 1375

-    assert response.instruction_uri == 'instruction_uri_value'
+    assert response.instruction_uri == "instruction_uri_value"

-    assert response.inputs_schema_uri == 'inputs_schema_uri_value'
+    assert response.inputs_schema_uri == "inputs_schema_uri_value"

     assert response.state == job_state.JobState.JOB_STATE_QUEUED

     assert response.labeling_progress == 1810

-    assert response.specialist_pools == ['specialist_pools_value']
+    assert response.specialist_pools == ["specialist_pools_value"]


 def test_create_data_labeling_job_from_dict():
@@ -1729,10 +1645,9 @@

 @pytest.mark.asyncio
-async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio'):
+async def test_create_data_labeling_job_async(transport: str = "grpc_asyncio"):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1741,20 +1656,22 @@
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob(
-            name='name_value',
-            display_name='display_name_value',
-            datasets=['datasets_value'],
-            labeler_count=1375,
-            instruction_uri='instruction_uri_value',
-            inputs_schema_uri='inputs_schema_uri_value',
-            state=job_state.JobState.JOB_STATE_QUEUED,
-            labeling_progress=1810,
-            specialist_pools=['specialist_pools_value'],
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_data_labeling_job.DataLabelingJob(
+                name="name_value",
+                display_name="display_name_value",
+                datasets=["datasets_value"],
+                labeler_count=1375,
+                instruction_uri="instruction_uri_value",
+                inputs_schema_uri="inputs_schema_uri_value",
+                state=job_state.JobState.JOB_STATE_QUEUED,
+                labeling_progress=1810,
+                specialist_pools=["specialist_pools_value"],
+            )
+        )

         response = await client.create_data_labeling_job(request)

@@ -1767,39 +1684,37 @@
     # Establish that the response is the type that we expect.
     assert isinstance(response, gca_data_labeling_job.DataLabelingJob)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.datasets == ['datasets_value']
+    assert response.datasets == ["datasets_value"]

     assert response.labeler_count == 1375

-    assert response.instruction_uri == 'instruction_uri_value'
+    assert response.instruction_uri == "instruction_uri_value"

-    assert response.inputs_schema_uri == 'inputs_schema_uri_value'
+    assert response.inputs_schema_uri == "inputs_schema_uri_value"

     assert response.state == job_state.JobState.JOB_STATE_QUEUED

     assert response.labeling_progress == 1810

-    assert response.specialist_pools == ['specialist_pools_value']
+    assert response.specialist_pools == ["specialist_pools_value"]


 def test_create_data_labeling_job_field_headers():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CreateDataLabelingJobRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
         call.return_value = gca_data_labeling_job.DataLabelingJob()

         client.create_data_labeling_job(request)

@@ -1811,28 +1726,25 @@
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_create_data_labeling_job_field_headers_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CreateDataLabelingJobRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_data_labeling_job),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob())
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_data_labeling_job.DataLabelingJob()
+        )

         await client.create_data_labeling_job(request)

@@ -1843,29 +1755,24 @@
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_create_data_labeling_job_flattened():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_data_labeling_job.DataLabelingJob()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.create_data_labeling_job(
-            parent='parent_value',
-            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'),
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -1873,45 +1780,45 @@
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value')
+        assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(
+            name="name_value"
+        )


 def test_create_data_labeling_job_flattened_error():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.create_data_labeling_job(
             job_service.CreateDataLabelingJobRequest(),
-            parent='parent_value',
-            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'),
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
         )


 @pytest.mark.asyncio
 async def test_create_data_labeling_job_flattened_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_data_labeling_job.DataLabelingJob()
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_data_labeling_job.DataLabelingJob()
+        )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.create_data_labeling_job(
-            parent='parent_value',
-            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'),
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -1919,31 +1826,32 @@
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value')
+        assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(
+            name="name_value"
+        )


 @pytest.mark.asyncio
 async def test_create_data_labeling_job_flattened_error_async():
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.create_data_labeling_job(
             job_service.CreateDataLabelingJobRequest(),
-            parent='parent_value',
-            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'),
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
         )


-def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest):
+def test_get_data_labeling_job(
+    transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest
+):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1952,28 +1860,19 @@
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.get_data_labeling_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = data_labeling_job.DataLabelingJob(
-            name='name_value',
-
-            display_name='display_name_value',
-
-            datasets=['datasets_value'],
-
+            name="name_value",
+            display_name="display_name_value",
+            datasets=["datasets_value"],
             labeler_count=1375,
-
-            instruction_uri='instruction_uri_value',
-
-            inputs_schema_uri='inputs_schema_uri_value',
-
+            instruction_uri="instruction_uri_value",
+            inputs_schema_uri="inputs_schema_uri_value",
             state=job_state.JobState.JOB_STATE_QUEUED,
-
             labeling_progress=1810,
-
-            specialist_pools=['specialist_pools_value'],
-
+            specialist_pools=["specialist_pools_value"],
         )

         response = client.get_data_labeling_job(request)

@@ -1987,23 +1886,23 @@
     # Establish that the response is the type that we expect.
     assert isinstance(response, data_labeling_job.DataLabelingJob)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.datasets == ['datasets_value']
+    assert response.datasets == ["datasets_value"]

     assert response.labeler_count == 1375

-    assert response.instruction_uri == 'instruction_uri_value'
+    assert response.instruction_uri == "instruction_uri_value"

-    assert response.inputs_schema_uri == 'inputs_schema_uri_value'
+    assert response.inputs_schema_uri == "inputs_schema_uri_value"

     assert response.state == job_state.JobState.JOB_STATE_QUEUED

     assert response.labeling_progress == 1810

-    assert response.specialist_pools == ['specialist_pools_value']
+    assert response.specialist_pools == ["specialist_pools_value"]


 def test_get_data_labeling_job_from_dict():
@@ -2011,10 +1910,9 @@

 @pytest.mark.asyncio
-async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio'):
+async def test_get_data_labeling_job_async(transport: str = "grpc_asyncio"):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2023,20 +1921,22 @@
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_data_labeling_job),
-        '__call__') as call:
+        type(client.transport.get_data_labeling_job), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob(
-            name='name_value',
-            display_name='display_name_value',
-            datasets=['datasets_value'],
-            labeler_count=1375,
-            instruction_uri='instruction_uri_value',
-            inputs_schema_uri='inputs_schema_uri_value',
-            state=job_state.JobState.JOB_STATE_QUEUED,
-            labeling_progress=1810,
-            specialist_pools=['specialist_pools_value'],
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            data_labeling_job.DataLabelingJob(
+                name="name_value",
+                display_name="display_name_value",
+                datasets=["datasets_value"],
+                labeler_count=1375,
+                instruction_uri="instruction_uri_value",
+                inputs_schema_uri="inputs_schema_uri_value",
+                state=job_state.JobState.JOB_STATE_QUEUED,
+                labeling_progress=1810,
+                specialist_pools=["specialist_pools_value"],
+            )
+        )

         response = await client.get_data_labeling_job(request)

@@ -2049,39 +1949,37 @@
     # Establish that the response is the type that we expect.
assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_get_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -2093,28 +1991,25 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) await client.get_data_labeling_job(request) @@ -2125,99 +2020,85 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job( - name='name_value', - ) + client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job( - name='name_value', - ) + response = await client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
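# The rule exercised here is enforced client-side and is uniform across the
# service: populate the request object *or* the flattened keyword arguments,
# never both. The ValueError fires before any transport activity, which is
# why these error tests need no mocked call.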
with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) -def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): +def test_list_data_labeling_jobs( + transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2226,12 +2107,11 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_labeling_jobs(request) @@ -2245,7 +2125,7 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_labeling_jobs_from_dict(): @@ -2253,10 +2133,9 @@ def test_list_data_labeling_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio'): +async def test_list_data_labeling_jobs_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2265,12 +2144,14 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_labeling_jobs(request) @@ -2283,23 +2164,21 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2311,28 +2190,25 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) await client.list_data_labeling_jobs(request) @@ -2343,104 +2219,87 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_labeling_jobs( - parent='parent_value', - ) + client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
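# The flattened `parent="parent_value"` keyword was folded into a request
# proto before the transport was invoked, which is why args[0] below supports
# attribute access. A standalone sketch of that equivalence, assuming the
# regenerated types package:
from google.cloud.aiplatform_v1beta1.types import job_service as js

req = js.ListDataLabelingJobsRequest(parent="parent_value")
assert req.parent == "parent_value"
assert req == js.ListDataLabelingJobsRequest(parent="parent_value")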
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs( - parent='parent_value', - ) + response = await client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. 
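# (Note that these pager tests construct the client with the
# AnonymousCredentials class itself rather than an instance; this appears to
# be a benign generator quirk, since the mocked call never authenticates.)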
call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2449,17 +2308,14 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2472,9 +2328,7 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_labeling_jobs(request={}) @@ -2482,18 +2336,16 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in results) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) + def test_list_data_labeling_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2502,17 +2354,14 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2523,19 +2372,20 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
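# The tuple of responses queued on call.side_effect simulates a server that
# pages until it returns an empty next_page_token. A dependency-free sketch
# of the loop the pagers implement (names here are illustrative only):
def _iterate_pages(fetch):
    token = None
    while True:
        page = fetch(token)
        yield page
        token = page["next_page_token"]
        if not token:
            break


_pages = iter(
    [
        {"items": [1, 2, 3], "next_page_token": "abc"},
        {"items": [], "next_page_token": "def"},
        {"items": [4], "next_page_token": "ghi"},
        {"items": [5, 6], "next_page_token": ""},
    ]
)
_collected = list(_iterate_pages(lambda _token: next(_pages)))
assert sum(len(p["items"]) for p in _collected) == 6  # mirrors the 6 jobs above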
call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2544,17 +2394,14 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2565,25 +2412,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in responses) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2592,17 +2439,14 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2615,14 +2459,15 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): +def test_delete_data_labeling_job( + transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2631,10 +2476,10 @@ def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. 
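# Delete is a long-running method, so the transport-level mock below returns
# a raw google.longrunning Operation rather than a domain message. A
# standalone sketch of that message, assuming googleapis-common-protos
# (already pulled in by this library):
from google.longrunning import operations_pb2

op = operations_pb2.Operation(name="operations/spam")
assert op.name == "operations/spam"
assert not op.done  # neither a response nor an error has been recorded yet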
with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_data_labeling_job(request) @@ -2653,10 +2498,9 @@ def test_delete_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio'): +async def test_delete_data_labeling_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2665,11 +2509,11 @@ async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_data_labeling_job(request) @@ -2685,20 +2529,18 @@ async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio'): def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_data_labeling_job(request) @@ -2709,28 +2551,25 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_data_labeling_job(request) @@ -2741,101 +2580,85 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_data_labeling_job( - name='name_value', - ) + client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job( - name='name_value', - ) + response = await client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
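# A note on the doubled `call.return_value` assignment above: the second,
# FakeUnaryUnaryCall-wrapped value simply overwrites the first, ensuring the
# awaited call yields an awaitable as the async client requires. The checks
# below then unpack mock_calls entries, which destructure as
# (name, args, kwargs); a stdlib-only sketch:
from unittest import mock

_m = mock.Mock()
_m("request-proto", metadata=())
_name, _args, _kwargs = _m.mock_calls[0]
assert _args == ("request-proto",)
assert _kwargs == {"metadata": ()}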
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) -def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): +def test_cancel_data_labeling_job( + transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2844,8 +2667,8 @@ def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2866,10 +2689,9 @@ def test_cancel_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio'): +async def test_cancel_data_labeling_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2878,8 +2700,8 @@ async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -2896,19 +2718,17 @@ async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio'): def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -2920,27 +2740,22 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. 
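# The metadata entry being checked is produced by
# gapic_v1.routing_header.to_grpc_metadata, which serializes resource fields
# into a single ("x-goog-request-params", ...) pair. A standalone sketch of
# that shape (the exact value encoding is the helper's concern):
from google.api_core import gapic_v1

_key, _value = gapic_v1.routing_header.to_grpc_metadata((("name", "name/value"),))
assert _key == "x-goog-request-params"
assert _value.startswith("name=")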
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -2952,99 +2767,84 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_data_labeling_job( - name='name_value', - ) + client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
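# Cancel maps to google.protobuf.Empty on the wire; the generated client
# surfaces that as a plain None, which is why the sync mocks return None and
# the async variants return FakeUnaryUnaryCall(None).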
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job( - name='name_value', - ) + response = await client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) -def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): +def test_create_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3053,22 +2853,16 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_hyperparameter_tuning_job(request) @@ -3082,9 +2876,9 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Establish that the response is the type that we expect. 
assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3100,10 +2894,9 @@ def test_create_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): +async def test_create_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3112,17 +2905,19 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_hyperparameter_tuning_job(request) @@ -3135,9 +2930,9 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Establish that the response is the type that we expect. assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3149,19 +2944,17 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -3173,28 +2966,25 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.create_hyperparameter_tuning_job(request) @@ -3205,29 +2995,26 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3235,45 +3022,51 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3281,31 +3074,36 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) -def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): +def test_get_hyperparameter_tuning_job( + transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3314,22 +3112,16 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_hyperparameter_tuning_job(request) @@ -3343,9 +3135,9 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3361,10 +3153,9 @@ def test_get_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): +async def test_get_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3373,17 +3164,19 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
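# The canned HyperparameterTuningJob mixes string, integer, and enum fields;
# proto-plus accepts each as a plain keyword argument. A standalone sketch,
# assuming the regenerated types package:
from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as htj
from google.cloud.aiplatform_v1beta1.types import job_state

_job = htj.HyperparameterTuningJob(
    display_name="display_name_value",
    max_trial_count=1609,
    parallel_trial_count=2128,
    state=job_state.JobState.JOB_STATE_QUEUED,
)
assert _job.max_trial_count == 1609
assert _job.state == job_state.JobState.JOB_STATE_QUEUED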
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_hyperparameter_tuning_job(request) @@ -3396,9 +3189,9 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3410,19 +3203,17 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3434,28 +3225,25 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.get_hyperparameter_tuning_job(request) @@ -3466,99 +3254,86 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job( - name='name_value', - ) + client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) -def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): +def test_list_hyperparameter_tuning_jobs( + transport: str = "grpc", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3567,12 +3342,11 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3586,7 +3360,7 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3594,10 +3368,9 @@ def test_list_hyperparameter_tuning_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio'): +async def test_list_hyperparameter_tuning_jobs_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3606,12 +3379,14 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3624,23 +3399,21 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3652,28 +3425,25 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) await client.list_hyperparameter_tuning_jobs(request) @@ -3684,104 +3454,87 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. 
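# The *_flattened_error tests pin down a GAPIC convention: callers pass
# either a complete request object or individual flattened fields, never
# both. A hypothetical sketch of the guard (names are illustrative, not the
# generated client's actual internals):

def _get_job(request=None, *, name=None):
    # Mixing a request object with flattened keyword arguments is rejected
    # before any RPC is attempted.
    if request is not None and name is not None:
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )


try:
    _get_job(request=object(), name="name_value")
except ValueError:
    pass  # this is the behavior the tests assert with pytest.raises
else:
    raise AssertionError("expected ValueError")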
call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3790,17 +3543,16 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3813,9 +3565,7 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -3823,18 +3573,19 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results + ) + def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3843,17 +3594,16 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3864,19 +3614,20 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. 
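# The pager tests feed mock side_effect a fixed series of page responses,
# terminated by a RuntimeError that must never be reached: the pager keeps
# re-issuing the RPC while next_page_token is non-empty. A generic sketch of
# that loop, independent of the generated pager classes, mirroring the
# 3 + 0 + 1 + 2 items and "abc"/"def"/"ghi" tokens above:

_pages = [(["a", "b", "c"], "abc"), ([], "def"), (["d"], "ghi"), (["e", "f"], "")]


def _paginate(fetch):
    items, token = [], None
    while token != "":
        page_items, token = fetch(token)
        items.extend(page_items)
    return items


assert _paginate(lambda _token: _pages.pop(0)) == ["a", "b", "c", "d", "e", "f"]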
with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3885,17 +3636,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3906,25 +3656,28 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses + ) + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
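# The async pager variants exercise the same token protocol through
# `async for`: the AsyncPager yields individual jobs, while its `.pages`
# view yields whole responses so the tests can zip raw_page.next_page_token
# against ["abc", "def", "ghi", ""]. A pure-asyncio sketch of the page-wise
# walk, with hypothetical names:

import asyncio


async def _demo_async_pages():
    pages = [([], "abc"), ([], "def"), ([], "ghi"), ([], "")]

    async def fetch(_token):
        return pages.pop(0)

    seen, token = [], None
    while token != "":
        _items, token = await fetch(token)
        seen.append(token)
    assert seen == ["abc", "def", "ghi", ""]


asyncio.run(_demo_async_pages())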
call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3933,17 +3686,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3954,16 +3706,20 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: + async for page_ in ( + await client.list_hyperparameter_tuning_jobs(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): +def test_delete_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3972,10 +3728,10 @@ def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_hyperparameter_tuning_job(request) @@ -3994,10 +3750,9 @@ def test_delete_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): +async def test_delete_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4006,11 +3761,11 @@ async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_hyperparameter_tuning_job(request) @@ -4026,20 +3781,18 @@ async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asy def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_hyperparameter_tuning_job(request) @@ -4050,28 +3803,25 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_hyperparameter_tuning_job(request) @@ -4082,101 +3832,86 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
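# Delete is a long-running operation: the mocked transport hands back a raw
# google.longrunning Operation proto (name="operations/spam"), which the
# generated client is expected to wrap in an operation future before
# returning it. A sketch of the raw message these mocks construct, using the
# real proto from googleapis-common-protos:

from google.longrunning import operations_pb2

_op = operations_pb2.Operation(name="operations/spam")
# A freshly created Operation is not done and carries no result payload yet.
assert _op.name == "operations/spam"
assert not _op.done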
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job( - name='name_value', - ) + client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) -def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): +def test_cancel_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4185,8 +3920,8 @@ def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -4207,10 +3942,9 @@ def test_cancel_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio'): +async def test_cancel_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4219,8 +3953,8 @@ async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -4237,19 +3971,17 @@ async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asy def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -4261,27 +3993,22 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. 
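# Cancel is declared to return google.protobuf.Empty, which the generated
# surface exposes to callers as plain None; that is why the mocks here set
# call.return_value = None (sync) or FakeUnaryUnaryCall(None) (async). A
# hedged sketch of that mapping with a hypothetical unwrap helper:

from google.protobuf import empty_pb2


def _unwrap(response_pb):
    # Empty carries no fields, so callers receive None instead of a message.
    return None if isinstance(response_pb, empty_pb2.Empty) else response_pb


assert _unwrap(empty_pb2.Empty()) is None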
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4293,99 +4020,83 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
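# Every test here fakes the stub with
# mock.patch.object(type(client.transport.<method>), "__call__"): transport
# methods are callable wrapper objects, and Python resolves dunder methods on
# the type, so patching the instance alone would not intercept the call. A
# self-contained sketch of that lookup rule:

from unittest import mock


class _CallableWrapper:
    def __call__(self):
        return "real"


_wrapper = _CallableWrapper()
with mock.patch.object(type(_wrapper), "__call__") as call:
    call.return_value = "faked"
    # _wrapper() resolves through type(_wrapper).__call__, i.e. the mock.
    assert _wrapper() == "faked"
assert _wrapper() == "real"  # the original __call__ returns on context exit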
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) -def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest): +def test_create_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4394,20 +4105,15 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_batch_prediction_job(request) @@ -4421,11 +4127,11 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Establish that the response is the type that we expect. 
assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4437,10 +4143,9 @@ def test_create_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio'): +async def test_create_batch_prediction_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4449,16 +4154,18 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_batch_prediction_job(request) @@ -4471,11 +4178,11 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4483,19 +4190,17 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: call.return_value = gca_batch_prediction_job.BatchPredictionJob() client.create_batch_prediction_job(request) @@ -4507,28 +4212,25 @@ def test_create_batch_prediction_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) await client.create_batch_prediction_job(request) @@ -4539,29 +4241,26 @@ async def test_create_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -4569,45 +4268,51 @@ def test_create_batch_prediction_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -4615,31 +4320,36 @@ async def test_create_batch_prediction_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) -def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest): +def test_get_batch_prediction_job( + transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4648,20 +4358,15 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. 
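# Assertions such as `args[0].batch_prediction_job ==
# gca_batch_prediction_job.BatchPredictionJob(name="name_value")` lean on
# protobuf value equality: messages compare equal field by field, so the
# tests never depend on object identity. Sketch with a stock proto from
# googleapis-common-protos rather than the generated types:

from google.longrunning import operations_pb2

assert operations_pb2.Operation(name="operations/op") == operations_pb2.Operation(
    name="operations/op"
)
assert operations_pb2.Operation(name="a") != operations_pb2.Operation(name="b")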
with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_batch_prediction_job(request) @@ -4675,11 +4380,11 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv # Establish that the response is the type that we expect. assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4691,10 +4396,9 @@ def test_get_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio'): +async def test_get_batch_prediction_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4703,16 +4407,18 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_batch_prediction_job(request) @@ -4725,11 +4431,11 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. 
assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4737,19 +4443,17 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio'): def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: call.return_value = batch_prediction_job.BatchPredictionJob() client.get_batch_prediction_job(request) @@ -4761,28 +4465,25 @@ def test_get_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) await client.get_batch_prediction_job(request) @@ -4793,99 +4494,85 @@ async def test_get_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.get_batch_prediction_job( - name='name_value', - ) + client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_batch_prediction_job( - name='name_value', - ) + response = await client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) -def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest): +def test_list_batch_prediction_jobs( + transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4894,12 +4581,11 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_batch_prediction_jobs(request) @@ -4913,7 +4599,7 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_batch_prediction_jobs_from_dict(): @@ -4921,10 +4607,9 @@ def test_list_batch_prediction_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio'): +async def test_list_batch_prediction_jobs_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4933,12 +4618,14 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio') # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_batch_prediction_jobs(request) @@ -4951,23 +4638,21 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio') # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: call.return_value = job_service.ListBatchPredictionJobsResponse() client.list_batch_prediction_jobs(request) @@ -4979,28 +4664,25 @@ def test_list_batch_prediction_jobs_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) await client.list_batch_prediction_jobs(request) @@ -5011,104 +4693,87 @@ async def test_list_batch_prediction_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_batch_prediction_jobs( - parent='parent_value', - ) + client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
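# A flattened call is sugar for building the request message: keyword
# arguments are copied onto a fresh request object before dispatch.
# Assuming the generated semantics exercised above, both invocations in
# this sketch send an identical ``ListBatchPredictionJobsRequest``
# (helper name illustrative):

def _sketch_flattened_equivalence(client):
    client.list_batch_prediction_jobs(parent="parent_value")
    client.list_batch_prediction_jobs(
        request=job_service.ListBatchPredictionJobsRequest(parent="parent_value")
    )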
with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs( - parent='parent_value', - ) + response = await client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. 
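# The pager tests feed the mocked RPC a ``side_effect`` sequence: three
# responses whose ``next_page_token`` values chain "abc" -> "def" -> "ghi"
# into a final empty token, followed by a RuntimeError sentinel that fails
# the test if the pager ever over-fetches. Note that these generated pager
# tests pass the ``AnonymousCredentials`` class itself rather than an
# instance; since the RPC is mocked, the credential is never used to
# authorize a real call, so the tests pass regardless.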
call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5117,17 +4782,14 @@ def test_list_batch_prediction_jobs_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5140,9 +4802,7 @@ def test_list_batch_prediction_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_batch_prediction_jobs(request={}) @@ -5150,18 +4810,18 @@ def test_list_batch_prediction_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in results) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results + ) + def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5170,17 +4830,14 @@ def test_list_batch_prediction_jobs_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5191,19 +4848,20 @@ def test_list_batch_prediction_jobs_pages(): RuntimeError, ) pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5212,17 +4870,14 @@ async def test_list_batch_prediction_jobs_async_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5233,25 +4888,27 @@ async def test_list_batch_prediction_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in responses) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses + ) + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5260,17 +4917,14 @@ async def test_list_batch_prediction_jobs_async_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5283,14 +4937,15 @@ async def test_list_batch_prediction_jobs_async_pages(): pages = [] async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest): +def test_delete_batch_prediction_job( + transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5299,10 +4954,10 @@ def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_batch_prediction_job(request) @@ -5321,10 +4976,9 @@ def test_delete_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio'): +async def test_delete_batch_prediction_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5333,11 +4987,11 @@ async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
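# The async variants wrap every canned payload in
# ``grpc_helpers_async.FakeUnaryUnaryCall`` so that awaiting the mocked RPC
# behaves like a real aio call, and delete is a long-running method, so the
# payload is an ``operations_pb2.Operation`` rather than a plain response.
# A sketch of the wrapping, reusing this module's imports (helper name
# illustrative):

async def _sketch_fake_async_call():
    call = grpc_helpers_async.FakeUnaryUnaryCall(
        operations_pb2.Operation(name="operations/spam")
    )
    # FakeUnaryUnaryCall resolves to the wrapped message when awaited.
    operation = await call
    assert operation.name == "operations/spam"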
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_batch_prediction_job(request) @@ -5353,20 +5007,18 @@ async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio' def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_batch_prediction_job(request) @@ -5377,28 +5029,25 @@ def test_delete_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_batch_prediction_job(request) @@ -5409,101 +5058,85 @@ async def test_delete_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_batch_prediction_job( - name='name_value', - ) + client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job( - name='name_value', - ) + response = await client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
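# Mixing the two calling conventions is ambiguous, so it is rejected
# eagerly: whenever a request object is supplied, every flattened field
# must be left unset, and the client raises ValueError before any RPC is
# attempted.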
with pytest.raises(ValueError): await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) -def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest): +def test_cancel_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5512,8 +5145,8 @@ def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -5534,10 +5167,9 @@ def test_cancel_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio'): +async def test_cancel_batch_prediction_job_async(transport: str = "grpc_asyncio"): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5546,8 +5178,8 @@ async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -5564,19 +5196,17 @@ async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio' def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = None client.cancel_batch_prediction_job(request) @@ -5588,27 +5218,22 @@ def test_cancel_batch_prediction_job_field_headers(): # Establish that the field header was sent. 
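# Request routing works by mirroring URI path fields into the
# ``x-goog-request-params`` metadata entry; slashes in resource names stay
# unescaped so the header remains readable. A sketch using the helper this
# module already imports (helper name illustrative):

def _sketch_routing_header():
    key, value = gapic_v1.routing_header.to_grpc_metadata((("name", "name/value"),))
    assert (key, value) == ("x-goog-request-params", "name=name/value")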
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_batch_prediction_job(request) @@ -5620,92 +5245,75 @@ async def test_cancel_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_batch_prediction_job( - name='name_value', - ) + client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job( - name='name_value', - ) + response = await client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @@ -5716,8 +5324,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -5736,8 +5343,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -5765,13 +5371,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.JobServiceGrpcTransport, - transports.JobServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -5779,13 +5385,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.JobServiceGrpcTransport, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.JobServiceGrpcTransport,) def test_job_service_base_transport_error(): @@ -5793,13 +5394,15 @@ def test_job_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_job_service_base_transport(): # Instantiate the base transport. 
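# ``JobServiceTransport`` is the abstract base shared by the gRPC and
# gRPC-asyncio transports: it owns credential handling and the wrapped
# method table, while every RPC body raises NotImplementedError until a
# concrete transport overrides it. The loop below proves that contract for
# all twenty job-service RPCs (five verbs across four job types).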
- with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -5808,27 +5411,27 @@ def test_job_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_custom_job', - 'get_custom_job', - 'list_custom_jobs', - 'delete_custom_job', - 'cancel_custom_job', - 'create_data_labeling_job', - 'get_data_labeling_job', - 'list_data_labeling_jobs', - 'delete_data_labeling_job', - 'cancel_data_labeling_job', - 'create_hyperparameter_tuning_job', - 'get_hyperparameter_tuning_job', - 'list_hyperparameter_tuning_jobs', - 'delete_hyperparameter_tuning_job', - 'cancel_hyperparameter_tuning_job', - 'create_batch_prediction_job', - 'get_batch_prediction_job', - 'list_batch_prediction_jobs', - 'delete_batch_prediction_job', - 'cancel_batch_prediction_job', - ) + "create_custom_job", + "get_custom_job", + "list_custom_jobs", + "delete_custom_job", + "cancel_custom_job", + "create_data_labeling_job", + "get_data_labeling_job", + "list_data_labeling_jobs", + "delete_data_labeling_job", + "cancel_data_labeling_job", + "create_hyperparameter_tuning_job", + "get_hyperparameter_tuning_job", + "list_hyperparameter_tuning_jobs", + "delete_hyperparameter_tuning_job", + "cancel_hyperparameter_tuning_job", + "create_batch_prediction_job", + "get_batch_prediction_job", + "list_batch_prediction_jobs", + "delete_batch_prediction_job", + "cancel_batch_prediction_job", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -5841,23 +5444,28 @@ def test_job_service_base_transport(): def test_job_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_job_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
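# "ADC" is Application Default Credentials: when neither ``credentials``
# nor ``credentials_file`` is supplied, construction falls back to
# ``google.auth.default()``, which is why these tests stub ``auth.default``
# and assert on the requested scope. A condensed sketch of the fallback
# (helper name illustrative):

def _sketch_adc_fallback():
    with mock.patch.object(auth, "default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
        JobServiceClient()  # no explicit credentials anywhere
        adc.assert_called_once()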
- with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport() @@ -5866,11 +5474,11 @@ def test_job_service_base_transport_with_adc(): def test_job_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) JobServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -5878,60 +5486,70 @@ def test_job_service_auth_adc(): def test_job_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.JobServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_host_with_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_job_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
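# Supplying a ready-made channel short-circuits channel creation: the
# transport adopts it verbatim, and only the host string is normalized (a
# bare hostname gains the default ":443" port). This sketch restates the
# sync case above so it can be compared with the asyncio twin that follows
# (helper name illustrative):

def _sketch_injected_channel():
    channel = grpc.insecure_channel("http://localhost/")
    transport = transports.JobServiceGrpcTransport(
        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"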
transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -5940,7 +5558,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -5956,26 +5574,27 @@ def test_job_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) -def test_job_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -5992,9 +5611,7 @@ def test_job_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -6003,16 +5620,12 @@ def test_job_service_transport_channel_mtls_with_adc( def test_job_service_grpc_lro_client(): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. 
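# Long-running methods such as delete_* resolve their futures through an
# api_core ``operations_v1.OperationsClient`` that rides on the same gRPC
# channel. The property is created lazily and cached, so repeated access
# must hand back the identical object rather than a fresh client.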
- assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -6020,36 +5633,36 @@ def test_job_service_grpc_lro_client(): def test_job_service_grpc_lro_async_client(): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_batch_prediction_job_path(): project = "squid" location = "clam" batch_prediction_job = "whelk" - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) - actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, location=location, batch_prediction_job=batch_prediction_job, + ) + actual = JobServiceClient.batch_prediction_job_path( + project, location, batch_prediction_job + ) assert expected == actual def test_parse_batch_prediction_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", } path = JobServiceClient.batch_prediction_job_path(**expected) @@ -6057,22 +5670,24 @@ def test_parse_batch_prediction_job_path(): actual = JobServiceClient.parse_batch_prediction_job_path(path) assert expected == actual + def test_custom_job_path(): project = "cuttlefish" location = "mussel" custom_job = "winkle" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) actual = JobServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", - + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", } path = JobServiceClient.custom_job_path(**expected) @@ -6080,22 +5695,26 @@ def test_parse_custom_job_path(): actual = JobServiceClient.parse_custom_job_path(path) assert expected == actual + def test_data_labeling_job_path(): project = "squid" location = "clam" data_labeling_job = "whelk" - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) - actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, 
location=location, data_labeling_job=data_labeling_job, + ) + actual = JobServiceClient.data_labeling_job_path( + project, location, data_labeling_job + ) assert expected == actual def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", } path = JobServiceClient.data_labeling_job_path(**expected) @@ -6103,22 +5722,24 @@ def test_parse_data_labeling_job_path(): actual = JobServiceClient.parse_data_labeling_job_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = JobServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = JobServiceClient.dataset_path(**expected) @@ -6126,22 +5747,28 @@ def test_parse_dataset_path(): actual = JobServiceClient.parse_dataset_path(path) assert expected == actual + def test_hyperparameter_tuning_job_path(): project = "squid" location = "clam" hyperparameter_tuning_job = "whelk" - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) - actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + actual = JobServiceClient.hyperparameter_tuning_job_path( + project, location, hyperparameter_tuning_job + ) assert expected == actual def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "hyperparameter_tuning_job": "nudibranch", - + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -6149,22 +5776,24 @@ def test_parse_hyperparameter_tuning_job_path(): actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = JobServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = JobServiceClient.model_path(**expected) @@ -6172,18 +5801,20 @@ def test_parse_model_path(): actual = JobServiceClient.parse_model_path(path) assert expected == actual + def 
test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = JobServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = JobServiceClient.common_billing_account_path(**expected) @@ -6191,18 +5822,18 @@ def test_parse_common_billing_account_path(): actual = JobServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = JobServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = JobServiceClient.common_folder_path(**expected) @@ -6210,18 +5841,18 @@ def test_parse_common_folder_path(): actual = JobServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = JobServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = JobServiceClient.common_organization_path(**expected) @@ -6229,18 +5860,18 @@ def test_parse_common_organization_path(): actual = JobServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = JobServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = JobServiceClient.common_project_path(**expected) @@ -6248,20 +5879,22 @@ def test_parse_common_project_path(): actual = JobServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = JobServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = JobServiceClient.common_location_path(**expected) @@ -6273,17 +5906,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) 
prep.assert_called_once_with(client_info) - with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = JobServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 865bcf4305..01aece3a3b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceClient, +) from google.cloud.aiplatform_v1beta1.services.migration_service import pagers from google.cloud.aiplatform_v1beta1.services.migration_service import transports from google.cloud.aiplatform_v1beta1.types import migratable_resource @@ -53,7 +57,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
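# mTLS endpoints are derived textually: "<name>.googleapis.com" becomes
# "<name>.mtls.googleapis.com", with sandbox hosts keeping their suffix
# (test__get_default_mtls_endpoint below verifies both). A "localhost"
# development endpoint has no such form, so this helper substitutes the
# neutral "foo.googleapis.com" whenever the client under test was built
# with a localhost default.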
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -64,17 +72,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [MigrationServiceClient, MigrationServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient] +) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -82,7 +109,7 @@ def test_migration_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_client_get_transport_class(): @@ -93,29 +120,44 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) -def test_migration_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + 
], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -131,7 +173,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -147,7 +189,7 @@ def test_migration_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -167,13 +209,15 @@ def test_migration_service_client_client_options(client_class, transport_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
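# Two environment variables drive the matrix below.
# GOOGLE_API_USE_MTLS_ENDPOINT picks the endpoint: "never" always uses the
# default endpoint, "always" always uses the mTLS endpoint, and "auto"
# switches to mTLS only when a client certificate is both permitted and
# available. GOOGLE_API_USE_CLIENT_CERTIFICATE ("true"/"false") decides
# whether a certificate supplied via options or ADC may actually be used.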
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -186,26 +230,66 @@ def test_migration_service_client_client_options(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) -@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -228,11 +312,21 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -242,7 +336,9 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -257,10 +353,17 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -275,16 +378,23 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -297,16 +407,24 @@ def test_migration_service_client_client_options_scopes(client_class, transport_ client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -321,10 +439,12 @@ def test_migration_service_client_client_options_credentials_file(client_class, def test_migration_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -337,10 +457,12 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -349,12 +471,11 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.search_migratable_resources(request) @@ -368,7 +489,7 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_migratable_resources_from_dict(): @@ -376,10 +497,9 @@ def test_search_migratable_resources_from_dict(): @pytest.mark.asyncio -async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio'): +async def test_search_migratable_resources_async(transport: str = "grpc_asyncio"): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -388,12 +508,14 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.search_migratable_resources(request) @@ -406,23 +528,21 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -434,10 +554,7 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -449,13 +566,15 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) await client.search_migratable_resources(request) @@ -466,49 +585,39 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources( - parent='parent_value', - ) + client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) @@ -520,24 +629,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources( - parent='parent_value', - ) + response = await client.search_migratable_resources(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -550,20 +659,17 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. 
with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), - parent='parent_value', + migration_service.SearchMigratableResourcesRequest(), parent="parent_value", ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -572,17 +678,14 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -595,9 +698,7 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.search_migratable_resources(request={}) @@ -605,18 +706,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in results) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in results + ) + def test_search_migratable_resources_pages(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__') as call: + type(client.transport.search_migratable_resources), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -625,17 +726,14 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -646,19 +744,20 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -667,17 +766,14 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -688,25 +784,27 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, migratable_resource.MigratableResource) - for i in responses) + assert all( + isinstance(i, migratable_resource.MigratableResource) for i in responses + ) + @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.search_migratable_resources), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.search_migratable_resources), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -715,17 +813,14 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token='abc', + next_page_token="abc", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], - next_page_token='def', + migratable_resources=[], next_page_token="def", ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[ - migratable_resource.MigratableResource(), - ], - next_page_token='ghi', + migratable_resources=[migratable_resource.MigratableResource(),], + next_page_token="ghi", ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -738,14 +833,15 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): +def test_batch_migrate_resources( + transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest +): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -754,10 +850,10 @@ def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.batch_migrate_resources(request) @@ -776,10 +872,9 @@ def test_batch_migrate_resources_from_dict(): @pytest.mark.asyncio -async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio'): +async def test_batch_migrate_resources_async(transport: str = "grpc_asyncio"): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -788,11 +883,11 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.batch_migrate_resources(request) @@ -808,20 +903,18 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio'): def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.batch_migrate_resources(request) @@ -832,10 +925,7 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -847,13 +937,15 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.batch_migrate_resources(request) @@ -864,29 +956,30 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -894,23 +987,33 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -922,19 +1025,25 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), - '__call__') as call: + type(client.transport.batch_migrate_resources), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.batch_migrate_resources( - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) # Establish that the underlying call was made with the expected @@ -942,9 +1051,15 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] @pytest.mark.asyncio @@ -958,8 +1073,14 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent='parent_value', - migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], ) @@ -970,8 +1091,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -990,8 +1110,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1019,13 +1138,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1033,13 +1155,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
- client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.MigrationServiceGrpcTransport, - ) + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) def test_migration_service_base_transport_error(): @@ -1047,13 +1164,15 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1062,9 +1181,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'search_migratable_resources', - 'batch_migrate_resources', - ) + "search_migratable_resources", + "batch_migrate_resources", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1077,23 +1196,28 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
- with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1102,11 +1226,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1114,60 +1238,75 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.MigrationServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_migration_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1176,7 +1315,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1192,26 +1331,30 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) -def test_migration_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1228,9 +1371,7 @@ def test_migration_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1239,16 +1380,12 @@ def test_migration_service_transport_channel_mtls_with_adc( def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we 
have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1256,36 +1393,36 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_annotated_dataset_path(): project = "squid" dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) - actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) + actual = MigrationServiceClient.annotated_dataset_path( + project, dataset, annotated_dataset + ) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", - + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1293,22 +1430,24 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", - + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", } path = MigrationServiceClient.dataset_path(**expected) @@ -1316,22 +1455,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): project = "squid" - location = "clam" - dataset = "whelk" + dataset = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, + ) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def 
test_parse_dataset_path(): expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", - + "project": "whelk", + "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1339,20 +1478,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_dataset_path(): - project = "cuttlefish" - dataset = "mussel" + project = "oyster" + location = "nudibranch" + dataset = "cuttlefish" - expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "winkle", - "dataset": "nautilus", - + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1360,22 +1503,24 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual + def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - + "project": "clam", + "location": "whelk", + "model": "octopus", } path = MigrationServiceClient.model_path(**expected) @@ -1383,22 +1528,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", - + "project": "mussel", + "location": "winkle", + "model": "nautilus", } path = MigrationServiceClient.model_path(**expected) @@ -1406,22 +1553,24 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual + def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) + expected = "projects/{project}/models/{model}/versions/{version}".format( + project=project, model=model, version=version, + ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", - + "project": "clam", + "model": "whelk", + "version": "octopus", } path = MigrationServiceClient.version_path(**expected) 
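The path tests above, and the test_parse_*_path cases that follow, all check one contract: a *_path() builder and its parse_*_path() inverse must round-trip. As a minimal sketch of that contract, not code from this patch, such a pair can be modeled as a format string plus a named-group regex; the template constant, regex, and module-level helper names below are illustrative assumptions.

import re

# Illustrative sketch of the builder/parser pair exercised by
# test_version_path() / test_parse_version_path() above; the template
# and regex are assumptions for illustration, not code from this patch.
_VERSION_TEMPLATE = "projects/{project}/models/{model}/versions/{version}"
_VERSION_RE = re.compile(
    r"^projects/(?P<project>.+?)/models/(?P<model>.+?)/versions/(?P<version>.+?)$"
)


def version_path(project: str, model: str, version: str) -> str:
    # Build the canonical resource name from its components.
    return _VERSION_TEMPLATE.format(project=project, model=model, version=version)


def parse_version_path(path: str) -> dict:
    # Invert version_path(); an empty dict means the path did not match.
    match = _VERSION_RE.match(path)
    return dict(match.groupdict()) if match else {}


# Round-trip check, mirroring the assertion style of the surrounding tests.
assert parse_version_path(version_path("clam", "whelk", "octopus")) == {
    "project": "clam",
    "model": "whelk",
    "version": "octopus",
}

Round-tripping build-then-parse with fresh placeholder values is exactly what each generated test_parse_*_path case in this hunk asserts.
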
@@ -1429,18 +1578,20 @@ def test_parse_version_path(): actual = MigrationServiceClient.parse_version_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", - + "billing_account": "nudibranch", } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1448,18 +1599,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", - + "folder": "mussel", } path = MigrationServiceClient.common_folder_path(**expected) @@ -1467,18 +1618,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", - + "organization": "nautilus", } path = MigrationServiceClient.common_organization_path(**expected) @@ -1486,18 +1637,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", - + "project": "abalone", } path = MigrationServiceClient.common_project_path(**expected) @@ -1505,20 +1656,22 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", - + "project": "whelk", + "location": "octopus", } path = MigrationServiceClient.common_location_path(**expected) @@ -1530,17 +1683,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + 
transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index bd93f3c4ee..d3c450ffb7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -35,7 +35,9 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.model_service import ( + ModelServiceAsyncClient, +) from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.services.model_service import transports @@ -65,7 +67,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -76,17 +82,30 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi @pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -94,7 +113,7 @@ def test_model_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_model_service_client_get_transport_class(): @@ -105,29 +124,42 @@ def test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -def test_model_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, 
transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -143,7 +175,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -159,7 +191,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -179,13 +211,15 @@ def test_model_service_client_client_options(client_class, transport_class, tran client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -198,26 +232,54 @@ def test_model_service_client_client_options(client_class, transport_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -240,11 +302,21 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -254,7 +326,9 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -269,10 +343,17 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -287,16 +368,23 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -309,16 +397,24 @@ def test_model_service_client_client_options_scopes(client_class, transport_clas client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -333,11 +429,11 @@ def test_model_service_client_client_options_credentials_file(client_class, tran def test_model_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -349,10 +445,11 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): +def test_upload_model( + transport: str = "grpc", request_type=model_service.UploadModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -360,11 +457,9 @@ def test_upload_model(transport: str = 'grpc', request_type=model_service.Upload request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.upload_model(request) @@ -383,10 +478,9 @@ def test_upload_model_from_dict(): @pytest.mark.asyncio -async def test_upload_model_async(transport: str = 'grpc_asyncio'): +async def test_upload_model_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -394,12 +488,10 @@ async def test_upload_model_async(transport: str = 'grpc_asyncio'): request = model_service.UploadModelRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.upload_model(request) @@ -415,20 +507,16 @@ async def test_upload_model_async(transport: str = 'grpc_asyncio'): def test_upload_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.upload_model(request) @@ -439,28 +527,23 @@ def test_upload_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.upload_model(request) @@ -471,29 +554,21 @@ async def test_upload_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_upload_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.upload_model( - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", model=gca_model.Model(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -501,47 +576,40 @@ def test_upload_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") def test_upload_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.upload_model( model_service.UploadModelRequest(), - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", + model=gca_model.Model(name="name_value"), ) @pytest.mark.asyncio async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.upload_model), - '__call__') as call: + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.upload_model( - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", model=gca_model.Model(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -549,31 +617,28 @@ async def test_upload_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") @pytest.mark.asyncio async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.upload_model( model_service.UploadModelRequest(), - parent='parent_value', - model=gca_model.Model(name='name_value'), + parent="parent_value", + model=gca_model.Model(name="name_value"), ) -def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest): +def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -581,31 +646,21 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - training_pipeline='training_pipeline_value', - - artifact_uri='artifact_uri_value', - - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - - supported_input_storage_formats=['supported_input_storage_formats_value'], - - supported_output_storage_formats=['supported_output_storage_formats_value'], - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", ) response = client.get_model(request) @@ -619,25 +674,31 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR # Establish that the response is the type that we expect. 
assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_model_from_dict(): @@ -645,10 +706,9 @@ def test_get_model_from_dict(): @pytest.mark.asyncio -async def test_get_model_async(transport: str = 'grpc_asyncio'): +async def test_get_model_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -656,22 +716,28 @@ async def test_get_model_async(transport: str = 'grpc_asyncio'): request = model_service.GetModelRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.get_model(request) @@ -684,41 +750,43 @@ async def test_get_model_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = model.Model() client.get_model(request) @@ -730,27 +798,20 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -762,99 +823,79 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model( - name='name_value', - ) + client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model( - name='name_value', - ) + response = await client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) -def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): +def test_list_models( + transport: str = "grpc", request_type=model_service.ListModelsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -862,13 +903,10 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_models(request) @@ -882,7 +920,7 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_models_from_dict(): @@ -890,10 +928,9 @@ def test_list_models_from_dict(): @pytest.mark.asyncio -async def test_list_models_async(transport: str = 'grpc_asyncio'): +async def test_list_models_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -901,13 +938,11 @@ async def test_list_models_async(transport: str = 'grpc_asyncio'): request = model_service.ListModelsRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_models(request) @@ -920,23 +955,19 @@ async def test_list_models_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_models_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: call.return_value = model_service.ListModelsResponse() client.list_models(request) @@ -948,28 +979,23 @@ def test_list_models_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) await client.list_models(request) @@ -980,138 +1006,98 @@ async def test_list_models_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_models_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_models( - parent='parent_value', - ) + client.list_models(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_models_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_models( - model_service.ListModelsRequest(), - parent='parent_value', + model_service.ListModelsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_models( - parent='parent_value', - ) + response = await client.list_models(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_models( - model_service.ListModelsRequest(), - parent='parent_value', + model_service.ListModelsRequest(), parent="parent_value", ) def test_list_models_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_models(request={}) @@ -1119,147 +1105,96 @@ def test_list_models_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model.Model) - for i in results) + assert all(isinstance(i, model.Model) for i in results) + def test_list_models_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[], - next_page_token='def', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_models_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) - for i in responses) + assert all(isinstance(i, model.Model) for i in responses) + @pytest.mark.asyncio async def test_list_models_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", ), + model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], + models=[model.Model(),], next_page_token="ghi", ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_models(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest): +def test_update_model( + transport: str = "grpc", request_type=model_service.UpdateModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1267,31 +1202,21 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = gca_model.Model( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - metadata_schema_uri='metadata_schema_uri_value', - - training_pipeline='training_pipeline_value', - - artifact_uri='artifact_uri_value', - - supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - - supported_input_storage_formats=['supported_input_storage_formats_value'], - - supported_output_storage_formats=['supported_output_storage_formats_value'], - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", ) response = client.update_model(request) @@ -1305,25 +1230,31 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_model_from_dict(): @@ -1331,10 +1262,9 @@ def test_update_model_from_dict(): @pytest.mark.asyncio -async def test_update_model_async(transport: str = 'grpc_asyncio'): +async def test_update_model_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1342,22 +1272,28 @@ async def test_update_model_async(transport: str = 'grpc_asyncio'): request = model_service.UpdateModelRequest() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.update_model(request) @@ -1370,41 +1306,43 @@ async def test_update_model_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
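    # GAPIC clients mirror URI path fields into the "x-goog-request-params"
    # metadata entry so proxies can route the call. The pair is built with
    # google.api_core.gapic_v1.routing_header (used directly in the pager
    # tests further down), roughly:
    #
    #   from google.api_core.gapic_v1 import routing_header
    #   routing_header.to_grpc_metadata((("model.name", request.model.name),))
    #   # -> ("x-goog-request-params", "model.name=model.name/value")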
request = model_service.UpdateModelRequest() - request.model.name = 'model.name/value' + request.model.name = "model.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: call.return_value = gca_model.Model() client.update_model(request) @@ -1416,27 +1354,20 @@ def test_update_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model.name=model.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = 'model.name/value' + request.model.name = "model.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) await client.update_model(request) @@ -1448,29 +1379,22 @@ async def test_update_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model.name=model.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] def test_update_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_model( - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1478,36 +1402,30 @@ def test_update_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_model), - '__call__') as call: + with mock.patch.object(type(client.transport.update_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() @@ -1515,8 +1433,8 @@ async def test_update_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model( - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1524,31 +1442,30 @@ async def test_update_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name='name_value') + assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): +def test_delete_model( + transport: str = "grpc", request_type=model_service.DeleteModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1556,11 +1473,9 @@ def test_delete_model(transport: str = 'grpc', request_type=model_service.Delete request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. 
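        # DeleteModel is a long-running operation, so the transport-level call
        # yields a raw operations_pb2.Operation; the client surface is then
        # expected to wrap it in a google.api_core.operation future. The name
        # (e.g. "operations/spam") is just a sentinel the assertions recognize.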
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_model(request) @@ -1579,10 +1494,9 @@ def test_delete_model_from_dict(): @pytest.mark.asyncio -async def test_delete_model_async(transport: str = 'grpc_asyncio'): +async def test_delete_model_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1590,12 +1504,10 @@ async def test_delete_model_async(transport: str = 'grpc_asyncio'): request = model_service.DeleteModelRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_model(request) @@ -1611,20 +1523,16 @@ async def test_delete_model_async(transport: str = 'grpc_asyncio'): def test_delete_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_model(request) @@ -1635,28 +1543,23 @@ def test_delete_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_model(request) @@ -1667,101 +1570,81 @@ async def test_delete_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model( - name='name_value', - ) + client.delete_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_model( - model_service.DeleteModelRequest(), - name='name_value', + model_service.DeleteModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_model), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_model( - name='name_value', - ) + response = await client.delete_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model( - model_service.DeleteModelRequest(), - name='name_value', + model_service.DeleteModelRequest(), name="name_value", ) -def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): +def test_export_model( + transport: str = "grpc", request_type=model_service.ExportModelRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1769,11 +1652,9 @@ def test_export_model(transport: str = 'grpc', request_type=model_service.Export request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_model(request) @@ -1792,10 +1673,9 @@ def test_export_model_from_dict(): @pytest.mark.asyncio -async def test_export_model_async(transport: str = 'grpc_asyncio'): +async def test_export_model_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1803,12 +1683,10 @@ async def test_export_model_async(transport: str = 'grpc_asyncio'): request = model_service.ExportModelRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_model(request) @@ -1824,20 +1702,16 @@ async def test_export_model_async(transport: str = 'grpc_asyncio'): def test_export_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_model(request) @@ -1848,28 +1722,23 @@ def test_export_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_model(request) @@ -1880,29 +1749,24 @@ async def test_export_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.export_model( - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) # Establish that the underlying call was made with the expected @@ -1910,47 +1774,47 @@ def test_export_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) def test_export_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_model( model_service.ExportModelRequest(), - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) @pytest.mark.asyncio async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_model), - '__call__') as call: + with mock.patch.object(type(client.transport.export_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
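        # Note the two consecutive call.return_value assignments above: the
        # first (a plain Operation) is immediately overwritten by the
        # FakeUnaryUnaryCall wrapper, which is the value the awaited stub
        # actually returns. Only the second assignment matters here.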
response = await client.export_model( - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) # Establish that the underlying call was made with the expected @@ -1958,31 +1822,34 @@ async def test_export_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) @pytest.mark.asyncio async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_model( model_service.ExportModelRequest(), - name='name_value', - output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), ) -def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): +def test_get_model_evaluation( + transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1991,16 +1858,13 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - - slice_dimensions=['slice_dimensions_value'], - + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], ) response = client.get_model_evaluation(request) @@ -2014,11 +1878,11 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic # Establish that the response is the type that we expect. 
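    # The client hands back the proto-plus wrapper, so scalar fields compare
    # directly against native Python values, roughly:
    #
    #   ev = model_evaluation.ModelEvaluation(name="name_value")
    #   assert ev.name == "name_value"          # str, not bytes
    #   assert list(ev.slice_dimensions) == []  # repeated fields default empty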
assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ['slice_dimensions_value'] + assert response.slice_dimensions == ["slice_dimensions_value"] def test_get_model_evaluation_from_dict(): @@ -2026,10 +1890,9 @@ def test_get_model_evaluation_from_dict(): @pytest.mark.asyncio -async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio'): +async def test_get_model_evaluation_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2038,14 +1901,16 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - slice_dimensions=['slice_dimensions_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation( + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], + ) + ) response = await client.get_model_evaluation(request) @@ -2058,27 +1923,25 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ['slice_dimensions_value'] + assert response.slice_dimensions == ["slice_dimensions_value"] def test_get_model_evaluation_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: call.return_value = model_evaluation.ModelEvaluation() client.get_model_evaluation(request) @@ -2090,28 +1953,25 @@ def test_get_model_evaluation_field_headers(): # Establish that the field header was sent. 
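    # Each entry in mock_calls unpacks as a (name, args, kwargs) triple, so
    # "_, _, kw = call.mock_calls[0]" grabs the keyword arguments of the first
    # stub invocation; kw["metadata"] is the tuple of gRPC metadata pairs that
    # the assertion below searches for the routing header.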
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + type(client.transport.get_model_evaluation), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) await client.get_model_evaluation(request) @@ -2122,99 +1982,85 @@ async def test_get_model_evaluation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_evaluation_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation( - name='name_value', - ) + client.get_model_evaluation(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), - name='name_value', + model_service.GetModelEvaluationRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_evaluation_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), - '__call__') as call: + type(client.transport.get_model_evaluation), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = model_evaluation.ModelEvaluation() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation( - name='name_value', - ) + response = await client.get_model_evaluation(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_evaluation_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), - name='name_value', + model_service.GetModelEvaluationRequest(), name="name_value", ) -def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest): +def test_list_model_evaluations( + transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2223,12 +2069,11 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluations(request) @@ -2242,7 +2087,7 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluations_from_dict(): @@ -2250,10 +2095,9 @@ def test_list_model_evaluations_from_dict(): @pytest.mark.asyncio -async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio'): +async def test_list_model_evaluations_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2262,12 +2106,14 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluations(request) @@ -2280,23 +2126,21 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluations_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationsResponse() client.list_model_evaluations(request) @@ -2308,28 +2152,25 @@ def test_list_model_evaluations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + type(client.transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) await client.list_model_evaluations(request) @@ -2340,104 +2181,87 @@ async def test_list_model_evaluations_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluations_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluations( - parent='parent_value', - ) + client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluations( - parent='parent_value', - ) + response = await client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) def test_list_model_evaluations_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2446,17 +2270,14 @@ def test_list_model_evaluations_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2469,9 +2290,7 @@ def test_list_model_evaluations_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluations(request={}) @@ -2479,18 +2298,16 @@ def test_list_model_evaluations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in results) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) + def test_list_model_evaluations_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. 
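        # With side_effect set to a sequence, each stub invocation pops the
        # next response, so consuming the pager walks the pages tokenized
        # "abc" -> "def" -> "ghi" -> ""; the trailing RuntimeError would only
        # surface if the pager issued one request too many.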
call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2499,17 +2316,14 @@ def test_list_model_evaluations_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2520,19 +2334,20 @@ def test_list_model_evaluations_pages(): RuntimeError, ) pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2541,17 +2356,14 @@ async def test_list_model_evaluations_async_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2562,25 +2374,25 @@ async def test_list_model_evaluations_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in responses) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) + @pytest.mark.asyncio async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
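        # The two async variants below exercise both iteration styles: the
        # async pager supports "async for", flattening items across pages,
        # while its .pages attribute yields the raw ListModelEvaluationsResponse
        # objects one per RPC, which is what the token assertions zip against.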
call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2589,17 +2401,14 @@ async def test_list_model_evaluations_async_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2612,14 +2421,15 @@ async def test_list_model_evaluations_async_pages(): pages = [] async for page_ in (await client.list_model_evaluations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): +def test_get_model_evaluation_slice( + transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2628,14 +2438,11 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - + name="name_value", metrics_schema_uri="metrics_schema_uri_value", ) response = client.get_model_evaluation_slice(request) @@ -2649,9 +2456,9 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" def test_get_model_evaluation_slice_from_dict(): @@ -2659,10 +2466,9 @@ def test_get_model_evaluation_slice_from_dict(): @pytest.mark.asyncio -async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio'): +async def test_get_model_evaluation_slice_async(transport: str = "grpc_asyncio"): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2671,13 +2477,14 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio') # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + ) response = await client.get_model_evaluation_slice(request) @@ -2690,25 +2497,23 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio') # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: call.return_value = model_evaluation_slice.ModelEvaluationSlice() client.get_model_evaluation_slice(request) @@ -2720,28 +2525,25 @@ def test_get_model_evaluation_slice_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) await client.get_model_evaluation_slice(request) @@ -2752,99 +2554,85 @@ async def test_get_model_evaluation_slice_field_headers_async(): # Establish that the field header was sent. 
    _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_get_model_evaluation_slice_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice),
-        '__call__') as call:
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
        # Designate an appropriate return value for the call.
        call.return_value = model_evaluation_slice.ModelEvaluationSlice()

        # Call the method with a truthy value for each flattened field,
        # using the keyword arguments to the method.
-        client.get_model_evaluation_slice(
-            name='name_value',
-        )
+        client.get_model_evaluation_slice(name="name_value",)

        # Establish that the underlying call was made with the expected
        # request object values.
        assert len(call.mock_calls) == 1
        _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_get_model_evaluation_slice_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

    # Attempting to call a method with both a request object and flattened
    # fields is an error.
    with pytest.raises(ValueError):
        client.get_model_evaluation_slice(
-            model_service.GetModelEvaluationSliceRequest(),
-            name='name_value',
+            model_service.GetModelEvaluationSliceRequest(), name="name_value",
        )


 @pytest.mark.asyncio
 async def test_get_model_evaluation_slice_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice),
-        '__call__') as call:
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
        # Designate an appropriate return value for the call.
        call.return_value = model_evaluation_slice.ModelEvaluationSlice()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation_slice.ModelEvaluationSlice()
+        )

        # Call the method with a truthy value for each flattened field,
        # using the keyword arguments to the method.
-        response = await client.get_model_evaluation_slice(
-            name='name_value',
-        )
+        response = await client.get_model_evaluation_slice(name="name_value",)

        # Establish that the underlying call was made with the expected
        # request object values.
        assert len(call.mock_calls)
        _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_get_model_evaluation_slice_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

    # Attempting to call a method with both a request object and flattened
    # fields is an error.
    with pytest.raises(ValueError):
        await client.get_model_evaluation_slice(
-            model_service.GetModelEvaluationSliceRequest(),
-            name='name_value',
+            model_service.GetModelEvaluationSliceRequest(), name="name_value",
        )


-def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest):
+def test_list_model_evaluation_slices(
+    transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest
+):
    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
    )

    # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2853,12 +2641,11 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        # Designate an appropriate return value for the call.
        call.return_value = model_service.ListModelEvaluationSlicesResponse(
-            next_page_token='next_page_token_value',
-
+            next_page_token="next_page_token_value",
        )

        response = client.list_model_evaluation_slices(request)
@@ -2872,7 +2659,7 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode
    # Establish that the response is the type that we expect.
    assert isinstance(response, pagers.ListModelEvaluationSlicesPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_list_model_evaluation_slices_from_dict():
@@ -2880,10 +2667,9 @@


 @pytest.mark.asyncio
-async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio'):
+async def test_list_model_evaluation_slices_async(transport: str = "grpc_asyncio"):
    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
    )

    # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2892,12 +2678,14 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse(
-            next_page_token='next_page_token_value',
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse(
+                next_page_token="next_page_token_value",
+            )
+        )

        response = await client.list_model_evaluation_slices(request)

@@ -2910,23 +2698,21 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio
    # Establish that the response is the type that we expect.
    assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_list_model_evaluation_slices_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

    # Any value that is part of the HTTP/1.1 URI should be sent as
    # a field header. Set these to a non-empty value.
    request = model_service.ListModelEvaluationSlicesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        call.return_value = model_service.ListModelEvaluationSlicesResponse()

        client.list_model_evaluation_slices(request)
@@ -2938,28 +2724,25 @@ def test_list_model_evaluation_slices_field_headers():
    # Establish that the field header was sent.
    _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

    # Any value that is part of the HTTP/1.1 URI should be sent as
    # a field header. Set these to a non-empty value.
    request = model_service.ListModelEvaluationSlicesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse())
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse()
+        )

        await client.list_model_evaluation_slices(request)

@@ -2970,104 +2753,87 @@ async def test_list_model_evaluation_slices_field_headers_async():
    # Establish that the field header was sent.
    _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_list_model_evaluation_slices_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        # Designate an appropriate return value for the call.
        call.return_value = model_service.ListModelEvaluationSlicesResponse()

        # Call the method with a truthy value for each flattened field,
        # using the keyword arguments to the method.
-        client.list_model_evaluation_slices(
-            parent='parent_value',
-        )
+        client.list_model_evaluation_slices(parent="parent_value",)

        # Establish that the underlying call was made with the expected
        # request object values.
        assert len(call.mock_calls) == 1
        _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 def test_list_model_evaluation_slices_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

    # Attempting to call a method with both a request object and flattened
    # fields is an error.
    with pytest.raises(ValueError):
        client.list_model_evaluation_slices(
-            model_service.ListModelEvaluationSlicesRequest(),
-            parent='parent_value',
+            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
        )


 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        # Designate an appropriate return value for the call.
        call.return_value = model_service.ListModelEvaluationSlicesResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse()
+        )

        # Call the method with a truthy value for each flattened field,
        # using the keyword arguments to the method.
-        response = await client.list_model_evaluation_slices(
-            parent='parent_value',
-        )
+        response = await client.list_model_evaluation_slices(parent="parent_value",)

        # Establish that the underlying call was made with the expected
        # request object values.
        assert len(call.mock_calls)
        _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

    # Attempting to call a method with both a request object and flattened
    # fields is an error.
    with pytest.raises(ValueError):
        await client.list_model_evaluation_slices(
-            model_service.ListModelEvaluationSlicesRequest(),
-            parent='parent_value',
+            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
        )


 def test_list_model_evaluation_slices_pager():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        # Set the response to a series of pages.
        call.side_effect = (
            model_service.ListModelEvaluationSlicesResponse(
@@ -3076,17 +2842,16 @@ def test_list_model_evaluation_slices_pager():
                    model_evaluation_slice.ModelEvaluationSlice(),
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='abc',
+                next_page_token="abc",
            ),
            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[],
-                next_page_token='def',
+                model_evaluation_slices=[], next_page_token="def",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='ghi',
+                next_page_token="ghi",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
@@ -3099,9 +2864,7 @@ def test_list_model_evaluation_slices_pager():
        metadata = ()
        metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', ''),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
        )
        pager = client.list_model_evaluation_slices(request={})

@@ -3109,18 +2872,18 @@ def test_list_model_evaluation_slices_pager():
        results = [i for i in pager]
        assert len(results) == 6
-        assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
-                   for i in results)
+        assert all(
+            isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results
+        )
+

 def test_list_model_evaluation_slices_pages():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__') as call:
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
        # Set the response to a series of pages.
        call.side_effect = (
            model_service.ListModelEvaluationSlicesResponse(
@@ -3129,17 +2892,16 @@ def test_list_model_evaluation_slices_pages():
                    model_evaluation_slice.ModelEvaluationSlice(),
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='abc',
+                next_page_token="abc",
            ),
            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[],
-                next_page_token='def',
+                model_evaluation_slices=[], next_page_token="def",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='ghi',
+                next_page_token="ghi",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
@@ -3150,19 +2912,20 @@ def test_list_model_evaluation_slices_pages():
            RuntimeError,
        )
        pages = list(client.list_model_evaluation_slices(request={}).pages)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
            assert page_.raw_page.next_page_token == token
+

 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_async_pager():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.list_model_evaluation_slices),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
        # Set the response to a series of pages.
        call.side_effect = (
            model_service.ListModelEvaluationSlicesResponse(
@@ -3171,17 +2934,16 @@ async def test_list_model_evaluation_slices_async_pager():
                    model_evaluation_slice.ModelEvaluationSlice(),
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='abc',
+                next_page_token="abc",
            ),
            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[],
-                next_page_token='def',
+                model_evaluation_slices=[], next_page_token="def",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='ghi',
+                next_page_token="ghi",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
@@ -3192,25 +2954,28 @@ async def test_list_model_evaluation_slices_async_pager():
            RuntimeError,
        )
        async_pager = await client.list_model_evaluation_slices(request={},)
-        assert async_pager.next_page_token == 'abc'
+        assert async_pager.next_page_token == "abc"
        responses = []
        async for response in async_pager:
            responses.append(response)

        assert len(responses) == 6
-        assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
-                   for i in responses)
+        assert all(
+            isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+            for i in responses
+        )
+

 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_async_pages():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

    # Mock the actual call within the gRPC stub, and fake the request.
    with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.list_model_evaluation_slices),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
        # Set the response to a series of pages.
        call.side_effect = (
            model_service.ListModelEvaluationSlicesResponse(
@@ -3219,17 +2984,16 @@ async def test_list_model_evaluation_slices_async_pages():
                    model_evaluation_slice.ModelEvaluationSlice(),
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='abc',
+                next_page_token="abc",
            ),
            model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[],
-                next_page_token='def',
+                model_evaluation_slices=[], next_page_token="def",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
                    model_evaluation_slice.ModelEvaluationSlice(),
                ],
-                next_page_token='ghi',
+                next_page_token="ghi",
            ),
            model_service.ListModelEvaluationSlicesResponse(
                model_evaluation_slices=[
@@ -3240,9 +3004,11 @@ async def test_list_model_evaluation_slices_async_pages():
            RuntimeError,
        )
        pages = []
-        async for page_ in (await client.list_model_evaluation_slices(request={})).pages:
+        async for page_ in (
+            await client.list_model_evaluation_slices(request={})
+        ).pages:
            pages.append(page_)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
            assert page_.raw_page.next_page_token == token


@@ -3253,8 +3019,7 @@ def test_credentials_transport_error():
    )
    with pytest.raises(ValueError):
        client = ModelServiceClient(
-            credentials=credentials.AnonymousCredentials(),
-            transport=transport,
+            credentials=credentials.AnonymousCredentials(), transport=transport,
        )

    # It is an error to provide a credentials file and a transport instance.
@@ -3273,8 +3038,7 @@ def test_credentials_transport_error():
    )
    with pytest.raises(ValueError):
        client = ModelServiceClient(
-            client_options={"scopes": ["1", "2"]},
-            transport=transport,
+            client_options={"scopes": ["1", "2"]}, transport=transport,
        )


@@ -3302,13 +3066,13 @@ def test_transport_get_channel():
    assert channel


-@pytest.mark.parametrize("transport_class", [
-    transports.ModelServiceGrpcTransport,
-    transports.ModelServiceGrpcAsyncIOTransport
-])
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
 def test_transport_adc(transport_class):
    # Test default credentials are used if not provided.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
        transport_class()
        adc.assert_called_once()
@@ -3316,13 +3080,8 @@ def test_transport_adc(transport_class):

 def test_transport_grpc_default():
    # A client should use the gRPC transport by default.
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    assert isinstance(
-        client.transport,
-        transports.ModelServiceGrpcTransport,
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    assert isinstance(client.transport, transports.ModelServiceGrpcTransport,)


 def test_model_service_base_transport_error():
@@ -3330,13 +3089,15 @@ def test_model_service_base_transport_error():
    with pytest.raises(exceptions.DuplicateCredentialArgs):
        transport = transports.ModelServiceTransport(
            credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json"
+            credentials_file="credentials.json",
        )


 def test_model_service_base_transport():
    # Instantiate the base transport.
-    with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__') as Transport:
+    with mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__"
+    ) as Transport:
        Transport.return_value = None
        transport = transports.ModelServiceTransport(
            credentials=credentials.AnonymousCredentials(),
@@ -3345,17 +3106,17 @@ def test_model_service_base_transport():

    # Every method on the transport should just blindly
    # raise NotImplementedError.
    methods = (
-        'upload_model',
-        'get_model',
-        'list_models',
-        'update_model',
-        'delete_model',
-        'export_model',
-        'get_model_evaluation',
-        'list_model_evaluations',
-        'get_model_evaluation_slice',
-        'list_model_evaluation_slices',
-    )
+        "upload_model",
+        "get_model",
+        "list_models",
+        "update_model",
+        "delete_model",
+        "export_model",
+        "get_model_evaluation",
+        "list_model_evaluations",
+        "get_model_evaluation_slice",
+        "list_model_evaluation_slices",
+    )
    for method in methods:
        with pytest.raises(NotImplementedError):
            getattr(transport, method)(request=object())
@@ -3368,23 +3129,28 @@ def test_model_service_base_transport():

 def test_model_service_base_transport_with_credentials_file():
    # Instantiate the base transport with a credentials file
-    with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport:
+    with mock.patch.object(
+        auth, "load_credentials_from_file"
+    ) as load_creds, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages"
+    ) as Transport:
        Transport.return_value = None
        load_creds.return_value = (credentials.AnonymousCredentials(), None)
        transport = transports.ModelServiceTransport(
-            credentials_file="credentials.json",
-            quota_project_id="octopus",
+            credentials_file="credentials.json", quota_project_id="octopus",
        )
-        load_creds.assert_called_once_with("credentials.json", scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',
-            ),
+        load_creds.assert_called_once_with(
+            "credentials.json",
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
            quota_project_id="octopus",
        )


 def test_model_service_base_transport_with_adc():
    # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport:
+    with mock.patch.object(auth, "default") as adc, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages"
+    ) as Transport:
        Transport.return_value = None
        adc.return_value = (credentials.AnonymousCredentials(), None)
        transport = transports.ModelServiceTransport()
@@ -3393,11 +3159,11 @@ def test_model_service_base_transport_with_adc():

 def test_model_service_auth_adc():
    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
        ModelServiceClient()
-        adc.assert_called_once_with(scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',),
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
            quota_project_id=None,
        )

@@ -3405,60 +3171,70 @@ def test_model_service_auth_adc():
 def test_model_service_transport_auth_adc():
    # If credentials and host are not provided, the transport class should use
    # ADC credentials.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
-        adc.assert_called_once_with(scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',),
+        transports.ModelServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
            quota_project_id="octopus",
        )

+
 def test_model_service_host_no_port():
    client = ModelServiceClient(
        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
    )
-    assert client.transport._host == 'aiplatform.googleapis.com:443'
+    assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_model_service_host_with_port():
    client = ModelServiceClient(
        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
    )
-    assert client.transport._host == 'aiplatform.googleapis.com:8000'
+    assert client.transport._host == "aiplatform.googleapis.com:8000"


 def test_model_service_grpc_transport_channel():
-    channel = grpc.insecure_channel('http://localhost/')
+    channel = grpc.insecure_channel("http://localhost/")

    # Check that channel is used if provided.
    transport = transports.ModelServiceGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
+        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"


 def test_model_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel('http://localhost/')
+    channel = aio.insecure_channel("http://localhost/")

    # Check that channel is used if provided.
    transport = transports.ModelServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
+        host="squid.clam.whelk", channel=channel,
    )
    assert transport.grpc_channel == channel
    assert transport._host == "squid.clam.whelk:443"


-@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
-def test_model_service_transport_channel_mtls_with_client_cert_source(
-    transport_class
-):
-    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
-        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
            mock_ssl_cred = mock.Mock()
            grpc_ssl_channel_cred.return_value = mock_ssl_cred

@@ -3467,7 +3243,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(
            cred = credentials.AnonymousCredentials()
            with pytest.warns(DeprecationWarning):
-                with mock.patch.object(auth, 'default') as adc:
+                with mock.patch.object(auth, "default") as adc:
                    adc.return_value = (cred, None)
                    transport = transport_class(
                        host="squid.clam.whelk",
@@ -3483,26 +3259,27 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(
                "mtls.squid.clam.whelk:443",
                credentials=cred,
                credentials_file=None,
-                scopes=(
-                    'https://www.googleapis.com/auth/cloud-platform',
-                ),
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
                ssl_credentials=mock_ssl_cred,
                quota_project_id=None,
            )
            assert transport.grpc_channel == mock_grpc_channel


-@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
-def test_model_service_transport_channel_mtls_with_adc(
-    transport_class
-):
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_transport_channel_mtls_with_adc(transport_class):
    mock_ssl_cred = mock.Mock()
    with mock.patch.multiple(
        "google.auth.transport.grpc.SslCredentials",
        __init__=mock.Mock(return_value=None),
        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
    ):
-        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
            mock_grpc_channel = mock.Mock()
            grpc_create_channel.return_value = mock_grpc_channel
            mock_cred = mock.Mock()
@@ -3519,9 +3296,7 @@ def test_model_service_transport_channel_mtls_with_adc(
                "mtls.squid.clam.whelk:443",
                credentials=mock_cred,
                credentials_file=None,
-                scopes=(
-                    'https://www.googleapis.com/auth/cloud-platform',
-                ),
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
                ssl_credentials=mock_ssl_cred,
                quota_project_id=None,
            )
@@ -3530,16 +3305,12 @@ def test_model_service_transport_channel_mtls_with_adc(

 def test_model_service_grpc_lro_client():
    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc',
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
    )
    transport = client.transport

    # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)

    # Ensure that subsequent calls to the property send the exact same object.
    assert transport.operations_client is transport.operations_client

@@ -3547,36 +3318,34 @@ def test_model_service_grpc_lro_client():

 def test_model_service_grpc_lro_async_client():
    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc_asyncio',
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
    )
    transport = client.transport

    # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsAsyncClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)

    # Ensure that subsequent calls to the property send the exact same object.
    assert transport.operations_client is transport.operations_client

+
 def test_endpoint_path():
    project = "squid"
    location = "clam"
    endpoint = "whelk"

-    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, )
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
    actual = ModelServiceClient.endpoint_path(project, location, endpoint)
    assert expected == actual


 def test_parse_endpoint_path():
    expected = {
-    "project": "octopus",
-    "location": "oyster",
-    "endpoint": "nudibranch",
-
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
    }
    path = ModelServiceClient.endpoint_path(**expected)

@@ -3584,22 +3353,24 @@ def test_parse_endpoint_path():
    actual = ModelServiceClient.parse_endpoint_path(path)
    assert expected == actual

+
 def test_model_path():
    project = "cuttlefish"
    location = "mussel"
    model = "winkle"

-    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
    actual = ModelServiceClient.model_path(project, location, model)
    assert expected == actual


 def test_parse_model_path():
    expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "model": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
    }
    path = ModelServiceClient.model_path(**expected)

@@ -3607,24 +3378,28 @@ def test_parse_model_path():
    actual = ModelServiceClient.parse_model_path(path)
    assert expected == actual

+
 def test_model_evaluation_path():
    project = "squid"
    location = "clam"
    model = "whelk"
    evaluation = "octopus"

-    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, )
-    actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation)
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(
+        project=project, location=location, model=model, evaluation=evaluation,
+    )
+    actual = ModelServiceClient.model_evaluation_path(
+        project, location, model, evaluation
+    )
    assert expected == actual


 def test_parse_model_evaluation_path():
    expected = {
-    "project": "oyster",
-    "location": "nudibranch",
-    "model": "cuttlefish",
-    "evaluation": "mussel",
-
+        "project": "oyster",
+        "location": "nudibranch",
+        "model": "cuttlefish",
+        "evaluation": "mussel",
    }
    path = ModelServiceClient.model_evaluation_path(**expected)

@@ -3632,6 +3407,7 @@ def test_parse_model_evaluation_path():
    actual = ModelServiceClient.parse_model_evaluation_path(path)
    assert expected == actual

+
 def test_model_evaluation_slice_path():
    project = "winkle"
    location = "nautilus"
@@ -3639,19 +3415,26 @@ def test_model_evaluation_slice_path():
    evaluation = "abalone"
    slice = "squid"

-    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, )
-    actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice)
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(
+        project=project,
+        location=location,
+        model=model,
+        evaluation=evaluation,
+        slice=slice,
+    )
+    actual = ModelServiceClient.model_evaluation_slice_path(
+        project, location, model, evaluation, slice
+    )
    assert expected == actual


 def test_parse_model_evaluation_slice_path():
    expected = {
-    "project": "clam",
-    "location": "whelk",
-    "model": "octopus",
-    "evaluation": "oyster",
-    "slice": "nudibranch",
-
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
+        "evaluation": "oyster",
+        "slice": "nudibranch",
    }
    path = ModelServiceClient.model_evaluation_slice_path(**expected)

@@ -3659,22 +3442,26 @@ def test_parse_model_evaluation_slice_path():
    actual = ModelServiceClient.parse_model_evaluation_slice_path(path)
    assert expected == actual

+
 def test_training_pipeline_path():
    project = "cuttlefish"
    location = "mussel"
    training_pipeline = "winkle"

-    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, )
-    actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline)
+    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
+        project=project, location=location, training_pipeline=training_pipeline,
+    )
+    actual = ModelServiceClient.training_pipeline_path(
+        project, location, training_pipeline
+    )
    assert expected == actual


 def test_parse_training_pipeline_path():
    expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "training_pipeline": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "training_pipeline": "abalone",
    }
    path = ModelServiceClient.training_pipeline_path(**expected)

@@ -3682,18 +3469,20 @@ def test_parse_training_pipeline_path():
    actual = ModelServiceClient.parse_training_pipeline_path(path)
    assert expected == actual

+
 def test_common_billing_account_path():
    billing_account = "squid"

-    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
    actual = ModelServiceClient.common_billing_account_path(billing_account)
    assert expected == actual


 def test_parse_common_billing_account_path():
    expected = {
-    "billing_account": "clam",
-
+        "billing_account": "clam",
    }
    path = ModelServiceClient.common_billing_account_path(**expected)

@@ -3701,18 +3490,18 @@ def test_parse_common_billing_account_path():
    actual = ModelServiceClient.parse_common_billing_account_path(path)
    assert expected == actual

+
 def test_common_folder_path():
    folder = "whelk"

-    expected = "folders/{folder}".format(folder=folder, )
+    expected = "folders/{folder}".format(folder=folder,)
    actual = ModelServiceClient.common_folder_path(folder)
    assert expected == actual


 def test_parse_common_folder_path():
    expected = {
-    "folder": "octopus",
-
+        "folder": "octopus",
    }
    path = ModelServiceClient.common_folder_path(**expected)

@@ -3720,18 +3509,18 @@ def test_parse_common_folder_path():
    actual = ModelServiceClient.parse_common_folder_path(path)
    assert expected == actual

+
 def test_common_organization_path():
    organization = "oyster"

-    expected = "organizations/{organization}".format(organization=organization, )
+    expected = "organizations/{organization}".format(organization=organization,)
    actual = ModelServiceClient.common_organization_path(organization)
    assert expected == actual


 def test_parse_common_organization_path():
    expected = {
-    "organization": "nudibranch",
-
+        "organization": "nudibranch",
    }
    path = ModelServiceClient.common_organization_path(**expected)

@@ -3739,18 +3528,18 @@ def test_parse_common_organization_path():
    actual = ModelServiceClient.parse_common_organization_path(path)
    assert expected == actual

+
 def test_common_project_path():
    project = "cuttlefish"

-    expected = "projects/{project}".format(project=project, )
+    expected = "projects/{project}".format(project=project,)
    actual = ModelServiceClient.common_project_path(project)
    assert expected == actual


 def test_parse_common_project_path():
    expected = {
-    "project": "mussel",
-
+        "project": "mussel",
    }
    path = ModelServiceClient.common_project_path(**expected)

@@ -3758,20 +3547,22 @@ def test_parse_common_project_path():
    actual = ModelServiceClient.parse_common_project_path(path)
    assert expected == actual

+
 def test_common_location_path():
    project = "winkle"
    location = "nautilus"

-    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
    actual = ModelServiceClient.common_location_path(project, location)
    assert expected == actual


 def test_parse_common_location_path():
    expected = {
-    "project": "scallop",
-    "location": "abalone",
-
+        "project": "scallop",
+        "location": "abalone",
    }
    path = ModelServiceClient.common_location_path(**expected)

@@ -3783,17 +3574,19 @@ def test_parse_common_location_path():

 def test_client_withDEFAULT_CLIENT_INFO():
    client_info = gapic_v1.client_info.ClientInfo()

-    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.ModelServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
        client = ModelServiceClient(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
        )
        prep.assert_called_once_with(client_info)

-    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.ModelServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
        transport_class = ModelServiceClient.get_transport_class()
        transport = transport_class(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
        )
        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
index 82f1b4f546..8a60b0a966 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
@@ -35,8 +35,12 @@
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceAsyncClient
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceClient
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
+    PipelineServiceAsyncClient,
+)
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
+    PipelineServiceClient,
+)
 from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers
 from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports
 from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
@@ -49,7 +53,9 @@
 from google.cloud.aiplatform_v1beta1.types import pipeline_service
 from google.cloud.aiplatform_v1beta1.types import pipeline_state
 from google.cloud.aiplatform_v1beta1.types import training_pipeline
-from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline
+from google.cloud.aiplatform_v1beta1.types import (
+    training_pipeline as gca_training_pipeline,
+)
 from google.longrunning import operations_pb2
 from google.oauth2 import service_account
 from google.protobuf import any_pb2 as gp_any  # type: ignore
@@ -67,7 +73,11 @@ def client_cert_source_callback():
 # This method modifies the default endpoint so the client can produce a different
 # mtls endpoint for endpoint testing purposes.
 def modify_default_endpoint(client):
-    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT
+    return (
+        "foo.googleapis.com"
+        if ("localhost" in client.DEFAULT_ENDPOINT)
+        else client.DEFAULT_ENDPOINT
+    )


 def test__get_default_mtls_endpoint():
@@ -78,17 +88,35 @@ def test__get_default_mtls_endpoint():
    non_googleapi = "api.example.com"

    assert PipelineServiceClient._get_default_mtls_endpoint(None) is None
-    assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(api_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+    )


-@pytest.mark.parametrize("client_class", [PipelineServiceClient, PipelineServiceAsyncClient])
+@pytest.mark.parametrize(
+    "client_class", [PipelineServiceClient, PipelineServiceAsyncClient]
+)
 def test_pipeline_service_client_from_service_account_file(client_class):
    creds = credentials.AnonymousCredentials()
-    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_file"
+    ) as factory:
        factory.return_value = creds
        client = client_class.from_service_account_file("dummy/file/path.json")
        assert client.transport._credentials == creds
@@ -96,7 +124,7 @@ def test_pipeline_service_client_from_service_account_file(client_class):
        client = client_class.from_service_account_json("dummy/file/path.json")
        assert client.transport._credentials == creds

-        assert client.transport._host == 'aiplatform.googleapis.com:443'
+        assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_pipeline_service_client_get_transport_class():
@@ -107,29 +135,44 @@ def test_pipeline_service_client_get_transport_class():
    assert transport == transports.PipelineServiceGrpcTransport


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient))
-@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient))
-def test_pipeline_service_client_client_options(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+@mock.patch.object(
+    PipelineServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceClient),
+)
+@mock.patch.object(
+    PipelineServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceAsyncClient),
+)
+def test_pipeline_service_client_client_options(
+    client_class, transport_class, transport_name
+):
    # Check that if channel is provided we won't create a new one.
-    with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc:
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials()
-        )
+    with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc:
+        transport = transport_class(credentials=credentials.AnonymousCredentials())
        client = client_class(transport=transport)
        gtc.assert_not_called()

    # Check that if channel is provided via str we will create a new one.
-    with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc:
+    with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc:
        client = client_class(transport=transport_name)
        gtc.assert_called()

    # Check the case api_endpoint is provided.
    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
        patched.return_value = None
        client = client_class(client_options=options)
        patched.assert_called_once_with(
@@ -145,7 +188,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
    # "never".
    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
            patched.return_value = None
            client = client_class()
            patched.assert_called_once_with(
@@ -161,7 +204,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
    # "always".
    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
            patched.return_value = None
            client = client_class()
            patched.assert_called_once_with(
@@ -181,13 +224,15 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
            client = client_class()

    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
        with pytest.raises(ValueError):
            client = client_class()

    # Check the case quota_project_id is provided
    options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
        patched.return_value = None
        client = client_class(client_options=options)
        patched.assert_called_once_with(
@@ -200,26 +245,66 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
            client_info=transports.base.DEFAULT_CLIENT_INFO,
        )

-@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false")
-])
-@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient))
-@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient))
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name,use_client_cert_env",
+    [
+        (
+            PipelineServiceClient,
+            transports.PipelineServiceGrpcTransport,
+            "grpc",
+            "true",
+        ),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "true",
+        ),
+        (
+            PipelineServiceClient,
+            transports.PipelineServiceGrpcTransport,
+            "grpc",
+            "false",
+        ),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "false",
+        ),
+    ],
+)
+@mock.patch.object(
+    PipelineServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceClient),
+)
+@mock.patch.object(
+    PipelineServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceAsyncClient),
+)
 @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
-def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+def test_pipeline_service_client_mtls_env_auto(
+    client_class, transport_class, transport_name, use_client_cert_env
+):
    # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
    # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.

    # Check the case client_cert_source is provided. Whether client cert is used depends on
    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
-        with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        options = client_options.ClientOptions(
+            client_cert_source=client_cert_source_callback
+        )
+        with mock.patch.object(transport_class, "__init__") as patched:
            ssl_channel_creds = mock.Mock()
-            with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds):
+            with mock.patch(
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
+            ):
                patched.return_value = None
                client = client_class(client_options=options)

@@ -242,11 +327,21 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr

    # Check the case ADC client cert is provided. Whether client cert is used depends on
    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
-                    with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
                        if use_client_cert_env == "false":
                            is_mtls_mock.return_value = False
                            ssl_credentials_mock.return_value = None
@@ -256,7 +351,9 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
                            is_mtls_mock.return_value = True
                            ssl_credentials_mock.return_value = mock.Mock()
                            expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = ssl_credentials_mock.return_value
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )

                        patched.return_value = None
                        client = client_class()
@@ -271,10 +368,17 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
                        )

    # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
                    is_mtls_mock.return_value = False
                    patched.return_value = None
                    client = client_class()
@@ -289,16 +393,23 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
                    )


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_pipeline_service_client_client_options_scopes(
+    client_class, transport_class, transport_name
+):
    # Check the case scopes are provided.
-    options = client_options.ClientOptions(
-        scopes=["1", "2"],
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(scopes=["1", "2"],)
+    with mock.patch.object(transport_class, "__init__") as patched:
        patched.return_value = None
        client = client_class(client_options=options)
        patched.assert_called_once_with(
@@ -311,16 +422,24 @@ def test_pipeline_service_client_client_options_scopes(client_class, transport_c
            client_info=transports.base.DEFAULT_CLIENT_INFO,
        )

-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name):
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_pipeline_service_client_client_options_credentials_file(
+    client_class, transport_class, transport_name
+):
    # Check the case credentials file is provided.
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -335,10 +454,12 @@ def test_pipeline_service_client_client_options_credentials_file(client_class, t def test_pipeline_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = PipelineServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -351,10 +472,11 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): +def test_create_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -363,18 +485,14 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline( - name='name_value', - - display_name='display_name_value', - - training_task_definition='training_task_definition_value', - + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) response = client.create_training_pipeline(request) @@ -388,11 +506,11 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline # Establish that the response is the type that we expect. 
assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -402,10 +520,9 @@ def test_create_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio'): +async def test_create_training_pipeline_async(transport: str = "grpc_asyncio"): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -414,15 +531,17 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( - name='name_value', - display_name='display_name_value', - training_task_definition='training_task_definition_value', - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) response = await client.create_training_pipeline(request) @@ -435,29 +554,27 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: call.return_value = gca_training_pipeline.TrainingPipeline() client.create_training_pipeline(request) @@ -469,28 +586,25 @@ def test_create_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) await client.create_training_pipeline(request) @@ -501,29 +615,24 @@ async def test_create_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.create_training_pipeline( - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -531,45 +640,45 @@ def test_create_training_pipeline_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), - '__call__') as call: + type(client.transport.create_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_training_pipeline( - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -577,31 +686,32 @@ async def test_create_training_pipeline_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent='parent_value', - training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), ) -def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest): +def test_get_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -610,18 +720,14 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline( - name='name_value', - - display_name='display_name_value', - - training_task_definition='training_task_definition_value', - + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) response = client.get_training_pipeline(request) @@ -635,11 +741,11 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -649,10 +755,9 @@ def test_get_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio'): +async def test_get_training_pipeline_async(transport: str = "grpc_asyncio"): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -661,15 +766,17 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( - name='name_value', - display_name='display_name_value', - training_task_definition='training_task_definition_value', - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) response = await client.get_training_pipeline(request) @@ -682,29 +789,27 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.training_task_definition == 'training_task_definition_value' + assert response.training_task_definition == "training_task_definition_value" assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: call.return_value = training_pipeline.TrainingPipeline() client.get_training_pipeline(request) @@ -716,28 +821,25 @@ def test_get_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) + type(client.transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) await client.get_training_pipeline(request) @@ -748,99 +850,85 @@ async def test_get_training_pipeline_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_training_pipeline( - name='name_value', - ) + client.get_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), - name='name_value', + pipeline_service.GetTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), - '__call__') as call: + type(client.transport.get_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_training_pipeline( - name='name_value', - ) + response = await client.get_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), - name='name_value', + pipeline_service.GetTrainingPipelineRequest(), name="name_value", ) -def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): +def test_list_training_pipelines( + transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -849,12 +937,11 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_training_pipelines(request) @@ -868,7 +955,7 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_training_pipelines_from_dict(): @@ -876,10 +963,9 @@ def test_list_training_pipelines_from_dict(): @pytest.mark.asyncio -async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio'): +async def test_list_training_pipelines_async(transport: str = "grpc_asyncio"): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -888,12 +974,14 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_training_pipelines(request) @@ -906,23 +994,21 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: call.return_value = pipeline_service.ListTrainingPipelinesResponse() client.list_training_pipelines(request) @@ -934,28 +1020,25 @@ def test_list_training_pipelines_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) await client.list_training_pipelines(request) @@ -966,104 +1049,87 @@ async def test_list_training_pipelines_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_training_pipelines_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_training_pipelines( - parent='parent_value', - ) + client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_training_pipelines_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_training_pipelines_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_training_pipelines( - parent='parent_value', - ) + response = await client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) def test_list_training_pipelines_pager(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1072,17 +1138,14 @@ def test_list_training_pipelines_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1095,9 +1158,7 @@ def test_list_training_pipelines_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_training_pipelines(request={}) @@ -1105,18 +1166,16 @@ def test_list_training_pipelines_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in results) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) + def test_list_training_pipelines_pages(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1125,17 +1184,14 @@ def test_list_training_pipelines_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1146,19 +1202,20 @@ def test_list_training_pipelines_pages(): RuntimeError, ) pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1167,17 +1224,14 @@ async def test_list_training_pipelines_async_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1188,25 +1242,25 @@ async def test_list_training_pipelines_async_pager(): RuntimeError, ) async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in responses) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) + @pytest.mark.asyncio async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1215,17 +1269,14 @@ async def test_list_training_pipelines_async_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1238,14 +1289,15 @@ async def test_list_training_pipelines_async_pages(): pages = [] async for page_ in (await client.list_training_pipelines(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest): +def test_delete_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1254,10 +1306,10 @@ def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_training_pipeline(request) @@ -1276,10 +1328,9 @@ def test_delete_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio'): +async def test_delete_training_pipeline_async(transport: str = "grpc_asyncio"): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1288,11 +1339,11 @@ async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_training_pipeline(request) @@ -1308,20 +1359,18 @@ async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio'): def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_training_pipeline(request) @@ -1332,28 +1381,25 @@ def test_delete_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_training_pipeline(request) @@ -1364,101 +1410,85 @@ async def test_delete_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_training_pipeline( - name='name_value', - ) + client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_training_pipeline( - name='name_value', - ) + response = await client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) -def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest): +def test_cancel_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1467,8 +1497,8 @@ def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1489,10 +1519,9 @@ def test_cancel_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio'): +async def test_cancel_training_pipeline_async(transport: str = "grpc_asyncio"): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1501,8 +1530,8 @@ async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1519,19 +1548,17 @@ async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio'): def test_cancel_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = None client.cancel_training_pipeline(request) @@ -1543,27 +1570,22 @@ def test_cancel_training_pipeline_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_training_pipeline(request) @@ -1575,92 +1597,75 @@ async def test_cancel_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_training_pipeline( - name='name_value', - ) + client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_training_pipeline( - name='name_value', - ) + response = await client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @@ -1671,8 +1676,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1691,8 +1695,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1720,13 +1723,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1734,13 +1740,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.PipelineServiceGrpcTransport, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) def test_pipeline_service_base_transport_error(): @@ -1748,13 +1749,15 @@ def test_pipeline_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_pipeline_service_base_transport(): # Instantiate the base transport. 
- with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1763,12 +1766,12 @@ def test_pipeline_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_training_pipeline', - 'get_training_pipeline', - 'list_training_pipelines', - 'delete_training_pipeline', - 'cancel_training_pipeline', - ) + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1781,23 +1784,28 @@ def test_pipeline_service_base_transport(): def test_pipeline_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_pipeline_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport() @@ -1806,11 +1814,11 @@ def test_pipeline_service_base_transport_with_adc(): def test_pipeline_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. 
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PipelineServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1818,60 +1826,75 @@ def test_pipeline_service_auth_adc(): def test_pipeline_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_pipeline_service_host_no_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_pipeline_service_host_with_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_pipeline_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_pipeline_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
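# Editor's note: the channel-injection tests here verify two behaviors at
# once: a caller-supplied gRPC channel is adopted instead of a new one being
# created, and a bare host is normalized to host:443. A hedged usage sketch
# against the real transport class, assuming grpc and this library are
# importable ("http://localhost/" is just a placeholder target, as in the
# tests themselves):
import grpc
from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports

channel = grpc.insecure_channel("http://localhost/")
transport = transports.PipelineServiceGrpcTransport(
    host="squid.clam.whelk", channel=channel,
)
assert transport.grpc_channel == channel  # the injected channel is used as-is
assert transport._host == "squid.clam.whelk:443"  # default port appended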
transport = transports.PipelineServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_pipeline_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1880,7 +1903,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1896,26 +1919,30 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) -def test_pipeline_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1932,9 +1959,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1943,16 +1968,12 @@ def test_pipeline_service_transport_channel_mtls_with_adc( def test_pipeline_service_grpc_lro_client(): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core 
operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1960,36 +1981,34 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = PipelineServiceClient.endpoint_path(**expected) @@ -1997,22 +2016,24 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = PipelineServiceClient.model_path(**expected) @@ -2020,22 +2041,26 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual + def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) - actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = PipelineServiceClient.training_pipeline_path( + project, location, training_pipeline + ) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", - + 
"project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2043,18 +2068,20 @@ def test_parse_training_pipeline_path(): actual = PipelineServiceClient.parse_training_pipeline_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = PipelineServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2062,18 +2089,18 @@ def test_parse_common_billing_account_path(): actual = PipelineServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = PipelineServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = PipelineServiceClient.common_folder_path(**expected) @@ -2081,18 +2108,18 @@ def test_parse_common_folder_path(): actual = PipelineServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = PipelineServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = PipelineServiceClient.common_organization_path(**expected) @@ -2100,18 +2127,18 @@ def test_parse_common_organization_path(): actual = PipelineServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = PipelineServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = PipelineServiceClient.common_project_path(**expected) @@ -2119,20 +2146,22 @@ def test_parse_common_project_path(): actual = PipelineServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = PipelineServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = PipelineServiceClient.common_location_path(**expected) @@ -2144,17 +2173,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = 
gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = PipelineServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index 9934ffb497..2e91a47bf4 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -32,8 +32,12 @@ from google.api_core import grpc_helpers_async from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.prediction_service import PredictionServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.prediction_service import PredictionServiceClient +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + PredictionServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + PredictionServiceClient, +) from google.cloud.aiplatform_v1beta1.services.prediction_service import transports from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service @@ -49,7 +53,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
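# Editor's note: this fixture, together with the _get_default_mtls_endpoint
# assertions just below, pins down a naming convention:
# "<api>.googleapis.com" maps to "<api>.mtls.googleapis.com", an endpoint
# already in mTLS form is returned unchanged, "<api>.sandbox.googleapis.com"
# maps to "<api>.mtls.sandbox.googleapis.com", and non-Google hosts pass
# through untouched. A hedged sketch of the same mapping, reconstructed from
# the assertions (an illustration of the convention, not the library's code):
def mtls_endpoint_sketch(endpoint):
    if endpoint is None or not endpoint.endswith(".googleapis.com"):
        return endpoint  # absent or non-Google hosts are left alone
    if ".mtls." in endpoint:
        return endpoint  # already an mTLS endpoint
    if endpoint.endswith(".sandbox.googleapis.com"):
        name = endpoint[: -len(".sandbox.googleapis.com")]
        return name + ".mtls.sandbox.googleapis.com"
    name = endpoint[: -len(".googleapis.com")]
    return name + ".mtls.googleapis.com"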
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -60,17 +68,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PredictionServiceClient._get_default_mtls_endpoint(None) is None - assert PredictionServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert PredictionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert PredictionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert PredictionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert PredictionServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + PredictionServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [PredictionServiceClient, PredictionServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient] +) def test_prediction_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -78,7 +105,7 @@ def test_prediction_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_prediction_service_client_get_transport_class(): @@ -89,29 +116,44 @@ def test_prediction_service_client_get_transport_class(): assert transport == transports.PredictionServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), - (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(PredictionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceClient)) -@mock.patch.object(PredictionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceAsyncClient)) -def test_prediction_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + 
transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + PredictionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) +def test_prediction_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PredictionServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PredictionServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -127,7 +169,7 @@ def test_prediction_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -143,7 +185,7 @@ def test_prediction_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -163,13 +205,15 @@ def test_prediction_service_client_client_options(client_class, transport_class, client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -182,26 +226,66 @@ def test_prediction_service_client_client_options(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc", "true"), - (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc", "false"), - (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(PredictionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceClient)) -@mock.patch.object(PredictionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PredictionServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + "true", + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + "false", + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + PredictionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_prediction_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_prediction_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
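# Editor's note: the endpoint auto-switch logic exercised below reduces to a
# small decision table, reconstructed from this test's assertions rather
# than from library docs. With GOOGLE_API_USE_MTLS_ENDPOINT=auto, the client
# uses DEFAULT_MTLS_ENDPOINT only when GOOGLE_API_USE_CLIENT_CERTIFICATE is
# "true" AND a client certificate is actually available (via
# client_cert_source or ADC); otherwise it stays on DEFAULT_ENDPOINT.
# "never" pins the default endpoint and "always" pins the mTLS endpoint
# regardless of certificates. The rule, as a hedged sketch:
def pick_endpoint_sketch(use_mtls_env, use_client_cert_env, have_cert,
                         default_endpoint, mtls_endpoint):
    if use_mtls_env == "never":
        return default_endpoint
    if use_mtls_env == "always":
        return mtls_endpoint
    # "auto": switch only if certs may be used and one actually exists
    if use_client_cert_env == "true" and have_cert:
        return mtls_endpoint
    return default_endpoint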
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -224,11 +308,21 @@ def test_prediction_service_client_mtls_env_auto(client_class, transport_class, # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -238,7 +332,9 @@ def test_prediction_service_client_mtls_env_auto(client_class, transport_class, is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -253,10 +349,17 @@ def test_prediction_service_client_mtls_env_auto(client_class, transport_class, ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -271,16 +374,23 @@ def test_prediction_service_client_mtls_env_auto(client_class, transport_class, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), - (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_prediction_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_prediction_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -293,16 +403,24 @@ def test_prediction_service_client_client_options_scopes(client_class, transport client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), - (PredictionServiceAsyncClient, transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_prediction_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_prediction_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -317,10 +435,12 @@ def test_prediction_service_client_client_options_credentials_file(client_class, def test_prediction_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = PredictionServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -333,10 +453,11 @@ def test_prediction_service_client_client_options_from_dict(): ) -def test_predict(transport: str = 'grpc', request_type=prediction_service.PredictRequest): +def test_predict( + transport: str = "grpc", request_type=prediction_service.PredictRequest +): client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -344,13 +465,10 @@ def test_predict(transport: str = 'grpc', request_type=prediction_service.Predic request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.predict), - '__call__') as call: + with mock.patch.object(type(client.transport.predict), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = prediction_service.PredictResponse( - deployed_model_id='deployed_model_id_value', - + deployed_model_id="deployed_model_id_value", ) response = client.predict(request) @@ -364,7 +482,7 @@ def test_predict(transport: str = 'grpc', request_type=prediction_service.Predic # Establish that the response is the type that we expect. assert isinstance(response, prediction_service.PredictResponse) - assert response.deployed_model_id == 'deployed_model_id_value' + assert response.deployed_model_id == "deployed_model_id_value" def test_predict_from_dict(): @@ -372,10 +490,9 @@ def test_predict_from_dict(): @pytest.mark.asyncio -async def test_predict_async(transport: str = 'grpc_asyncio'): +async def test_predict_async(transport: str = "grpc_asyncio"): client = PredictionServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -383,13 +500,13 @@ async def test_predict_async(transport: str = 'grpc_asyncio'): request = prediction_service.PredictRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.predict), - '__call__') as call: + with mock.patch.object(type(client.transport.predict), "__call__") as call: # Designate an appropriate return value for the call. 
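# Editor's note: grpc_helpers_async.FakeUnaryUnaryCall (from google.api_core)
# wraps a plain response so the mocked stub returns something the async
# client can `await`, exactly as it would await a real gRPC call object. The
# essence of such a wrapper, as a self-contained hedged sketch
# (AwaitableResponseSketch is a toy, not the api-core implementation):
import asyncio


class AwaitableResponseSketch:
    def __init__(self, response):
        self._response = response

    def __await__(self):
        # __await__ must return an iterator; yielding nothing and then
        # returning the response makes `await` resolve to it immediately.
        yield from ()
        return self._response


async def _demo():
    assert await AwaitableResponseSketch("ok") == "ok"


asyncio.run(_demo())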
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.PredictResponse( - deployed_model_id='deployed_model_id_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse( + deployed_model_id="deployed_model_id_value", + ) + ) response = await client.predict(request) @@ -402,23 +519,19 @@ async def test_predict_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, prediction_service.PredictResponse) - assert response.deployed_model_id == 'deployed_model_id_value' + assert response.deployed_model_id == "deployed_model_id_value" def test_predict_field_headers(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = prediction_service.PredictRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.predict), - '__call__') as call: + with mock.patch.object(type(client.transport.predict), "__call__") as call: call.return_value = prediction_service.PredictResponse() client.predict(request) @@ -430,10 +543,7 @@ def test_predict_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio @@ -445,13 +555,13 @@ async def test_predict_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = prediction_service.PredictRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.predict), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.PredictResponse()) + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) await client.predict(request) @@ -462,28 +572,21 @@ async def test_predict_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_predict_flattened(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.predict), - '__call__') as call: + with mock.patch.object(type(client.transport.predict), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = prediction_service.PredictResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.predict( - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), ) @@ -493,25 +596,25 @@ def test_predict_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)] + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] # https://github.com/googleapis/gapic-generator-python/issues/414 # assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) def test_predict_flattened_error(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.predict( prediction_service.PredictRequest(), - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), ) @@ -524,17 +627,17 @@ async def test_predict_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.predict), - '__call__') as call: + with mock.patch.object(type(client.transport.predict), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = prediction_service.PredictResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.PredictResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
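# Editor's note: every flattened-call test in this file enforces the same
# GAPIC convention -- a method accepts EITHER a request object OR flattened
# keyword fields, never both; mixing them raises ValueError before any RPC
# is attempted. A self-contained sketch of the check itself (a deliberate
# simplification of what the generated client does, not its actual code):
def call_with_flattening_sketch(request=None, **flattened):
    has_flattened = any(v is not None for v in flattened.values())
    if request is not None and has_flattened:
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )
    # The generated client would now build the request and send the RPC.
    return request or dict(flattened)


assert call_with_flattening_sketch(endpoint="endpoint_value")
try:
    call_with_flattening_sketch(object(), endpoint="endpoint_value")
except ValueError:
    pass  # mixing a request object with flattened fields is rejected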
response = await client.predict( - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), ) @@ -544,11 +647,15 @@ async def test_predict_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)] + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] - assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) + assert args[0].parameters == struct.Value( + null_value=struct.NullValue.NULL_VALUE + ) @pytest.mark.asyncio @@ -562,16 +669,17 @@ async def test_predict_flattened_error_async(): with pytest.raises(ValueError): await client.predict( prediction_service.PredictRequest(), - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), ) -def test_explain(transport: str = 'grpc', request_type=prediction_service.ExplainRequest): +def test_explain( + transport: str = "grpc", request_type=prediction_service.ExplainRequest +): client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -579,13 +687,10 @@ def test_explain(transport: str = 'grpc', request_type=prediction_service.Explai request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.explain), - '__call__') as call: + with mock.patch.object(type(client.transport.explain), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = prediction_service.ExplainResponse( - deployed_model_id='deployed_model_id_value', - + deployed_model_id="deployed_model_id_value", ) response = client.explain(request) @@ -599,7 +704,7 @@ def test_explain(transport: str = 'grpc', request_type=prediction_service.Explai # Establish that the response is the type that we expect. assert isinstance(response, prediction_service.ExplainResponse) - assert response.deployed_model_id == 'deployed_model_id_value' + assert response.deployed_model_id == "deployed_model_id_value" def test_explain_from_dict(): @@ -607,10 +712,9 @@ def test_explain_from_dict(): @pytest.mark.asyncio -async def test_explain_async(transport: str = 'grpc_asyncio'): +async def test_explain_async(transport: str = "grpc_asyncio"): client = PredictionServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -618,13 +722,13 @@ async def test_explain_async(transport: str = 'grpc_asyncio'): request = prediction_service.ExplainRequest() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.explain), - '__call__') as call: + with mock.patch.object(type(client.transport.explain), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.ExplainResponse( - deployed_model_id='deployed_model_id_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.ExplainResponse( + deployed_model_id="deployed_model_id_value", + ) + ) response = await client.explain(request) @@ -637,23 +741,19 @@ async def test_explain_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, prediction_service.ExplainResponse) - assert response.deployed_model_id == 'deployed_model_id_value' + assert response.deployed_model_id == "deployed_model_id_value" def test_explain_field_headers(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = prediction_service.ExplainRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.explain), - '__call__') as call: + with mock.patch.object(type(client.transport.explain), "__call__") as call: call.return_value = prediction_service.ExplainResponse() client.explain(request) @@ -665,10 +765,7 @@ def test_explain_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio @@ -680,13 +777,13 @@ async def test_explain_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = prediction_service.ExplainRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.explain), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.ExplainResponse()) + with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.ExplainResponse() + ) await client.explain(request) @@ -697,31 +794,24 @@ async def test_explain_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_explain_flattened(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.explain), - '__call__') as call: + with mock.patch.object(type(client.transport.explain), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = prediction_service.ExplainResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.explain( - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id='deployed_model_id_value', + deployed_model_id="deployed_model_id_value", ) # Establish that the underlying call was made with the expected @@ -729,30 +819,30 @@ def test_explain_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)] + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] # https://github.com/googleapis/gapic-generator-python/issues/414 # assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" def test_explain_flattened_error(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.explain( prediction_service.ExplainRequest(), - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id='deployed_model_id_value', + deployed_model_id="deployed_model_id_value", ) @@ -763,20 +853,20 @@ async def test_explain_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.explain), - '__call__') as call: + with mock.patch.object(type(client.transport.explain), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = prediction_service.ExplainResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(prediction_service.ExplainResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.ExplainResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.explain( - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id='deployed_model_id_value', + deployed_model_id="deployed_model_id_value", ) # Establish that the underlying call was made with the expected @@ -784,13 +874,17 @@ async def test_explain_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [struct.Value(null_value=struct.NullValue.NULL_VALUE)] + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] - assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) + assert args[0].parameters == struct.Value( + null_value=struct.NullValue.NULL_VALUE + ) - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" @pytest.mark.asyncio @@ -804,10 +898,10 @@ async def test_explain_flattened_error_async(): with pytest.raises(ValueError): await client.explain( prediction_service.ExplainRequest(), - endpoint='endpoint_value', + endpoint="endpoint_value", instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id='deployed_model_id_value', + deployed_model_id="deployed_model_id_value", ) @@ -818,8 +912,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -838,8 +931,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PredictionServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -867,13 +959,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.PredictionServiceGrpcTransport, - transports.PredictionServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -881,13 +976,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
- client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.PredictionServiceGrpcTransport, - ) + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PredictionServiceGrpcTransport,) def test_prediction_service_base_transport_error(): @@ -895,13 +985,15 @@ def test_prediction_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PredictionServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_prediction_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.PredictionServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -910,9 +1002,9 @@ def test_prediction_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'predict', - 'explain', - ) + "predict", + "explain", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -920,23 +1012,28 @@ def test_prediction_service_base_transport(): def test_prediction_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PredictionServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_prediction_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
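# Editor's note: the assertions below pin down how the transport resolves
# credentials when none are passed explicitly: it calls google.auth.default()
# with the cloud-platform scope and forwards quota_project_id (a credentials
# file instead goes through auth.load_credentials_from_file, as tested
# above). A hedged stand-alone sketch of that resolution order, assuming a
# google-auth release whose default() accepts quota_project_id, as the
# assertions here require:
import google.auth


def resolve_credentials_sketch(creds=None, quota_project_id=None):
    if creds is None:
        creds, _ = google.auth.default(
            scopes=("https://www.googleapis.com/auth/cloud-platform",),
            quota_project_id=quota_project_id,
        )
    return creds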
- with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PredictionServiceTransport() @@ -945,11 +1042,11 @@ def test_prediction_service_base_transport_with_adc(): def test_prediction_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PredictionServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -957,60 +1054,75 @@ def test_prediction_service_auth_adc(): def test_prediction_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PredictionServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.PredictionServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_prediction_service_host_no_port(): client = PredictionServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_prediction_service_host_with_port(): client = PredictionServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_prediction_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.PredictionServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_prediction_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.PredictionServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.PredictionServiceGrpcTransport, transports.PredictionServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) def test_prediction_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1019,7 +1131,7 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1035,26 +1147,30 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.PredictionServiceGrpcTransport, transports.PredictionServiceGrpcAsyncIOTransport]) -def test_prediction_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) +def test_prediction_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1071,9 +1187,7 @@ def test_prediction_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1085,17 +1199,18 @@ def test_endpoint_path(): location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, 
location=location, endpoint=endpoint, + ) actual = PredictionServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = PredictionServiceClient.endpoint_path(**expected) @@ -1103,18 +1218,20 @@ def test_parse_endpoint_path(): actual = PredictionServiceClient.parse_endpoint_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = PredictionServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = PredictionServiceClient.common_billing_account_path(**expected) @@ -1122,18 +1239,18 @@ def test_parse_common_billing_account_path(): actual = PredictionServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = PredictionServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = PredictionServiceClient.common_folder_path(**expected) @@ -1141,18 +1258,18 @@ def test_parse_common_folder_path(): actual = PredictionServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = PredictionServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = PredictionServiceClient.common_organization_path(**expected) @@ -1160,18 +1277,18 @@ def test_parse_common_organization_path(): actual = PredictionServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = PredictionServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = PredictionServiceClient.common_project_path(**expected) @@ -1179,20 +1296,22 @@ def test_parse_common_project_path(): actual = PredictionServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = PredictionServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": 
"nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = PredictionServiceClient.common_location_path(**expected) @@ -1204,17 +1323,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.PredictionServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PredictionServiceTransport, "_prep_wrapped_messages" + ) as prep: client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.PredictionServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PredictionServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = PredictionServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index 66d80fad2a..bb0461f5ee 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceClient +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -56,7 +60,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -67,17 +75,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient] +) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -85,7 +112,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_client_get_transport_class(): @@ -96,29 +123,48 @@ def test_specialist_pool_service_client_get_transport_class(): assert transport == transports.SpecialistPoolServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) -def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + 
transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -134,7 +180,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -150,7 +196,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -170,13 +216,15 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -189,26 +237,66 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "true", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "false", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_specialist_pool_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -231,11 +319,21 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -245,7 +343,9 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -260,10 +360,17 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -278,16 +385,27 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -300,16 +418,28 @@ def test_specialist_pool_service_client_client_options_scopes(client_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -324,10 +454,12 @@ def test_specialist_pool_service_client_client_options_credentials_file(client_c def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -340,10 +472,12 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +def test_create_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -352,10 +486,10 @@ def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_specialist_pool(request) @@ -374,10 +508,9 @@ def test_create_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio'): +async def test_create_specialist_pool_async(transport: str = "grpc_asyncio"): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -386,11 +519,11 @@ async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_specialist_pool(request) @@ -413,13 +546,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_specialist_pool(request) @@ -430,10 +563,7 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -445,13 +575,15 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_specialist_pool(request) @@ -462,10 +594,7 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_specialist_pool_flattened(): @@ -475,16 +604,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -492,9 +621,11 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) def test_create_specialist_pool_flattened_error(): @@ -507,8 +638,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) @@ -520,19 +651,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -540,9 +671,11 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) @pytest.mark.asyncio @@ -556,15 +689,17 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) -def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): +def test_get_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -573,20 +708,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", specialist_managers_count=2662, - - specialist_manager_emails=['specialist_manager_emails_value'], - - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], ) response = client.get_specialist_pool(request) @@ -600,15 +730,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po # Establish that the response is the type that we expect. 
assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] def test_get_specialist_pool_from_dict(): @@ -616,10 +746,9 @@ def test_get_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio'): +async def test_get_specialist_pool_async(transport: str = "grpc_asyncio"): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -628,16 +757,18 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( - name='name_value', - display_name='display_name_value', - specialist_managers_count=2662, - specialist_manager_emails=['specialist_manager_emails_value'], - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + ) response = await client.get_specialist_pool(request) @@ -650,15 +781,15 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] def test_get_specialist_pool_field_headers(): @@ -669,12 +800,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -686,10 +817,7 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -701,13 +829,15 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) await client.get_specialist_pool(request) @@ -718,10 +848,7 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_specialist_pool_flattened(): @@ -731,23 +858,21 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_specialist_pool( - name='name_value', - ) + client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_specialist_pool_flattened_error(): @@ -759,8 +884,7 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) @@ -772,24 +896,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool( - name='name_value', - ) + response = await client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -802,15 +926,16 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) -def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +def test_list_specialist_pools( + transport: str = "grpc", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -819,12 +944,11 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_specialist_pools(request) @@ -838,7 +962,7 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_specialist_pools_from_dict(): @@ -846,10 +970,9 @@ def test_list_specialist_pools_from_dict(): @pytest.mark.asyncio -async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio'): +async def test_list_specialist_pools_async(transport: str = "grpc_asyncio"): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -858,12 +981,14 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_specialist_pools(request) @@ -876,7 +1001,7 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_specialist_pools_field_headers(): @@ -887,12 +1012,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -904,10 +1029,7 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -919,13 +1041,15 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + type(client.transport.list_specialist_pools), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) await client.list_specialist_pools(request) @@ -936,10 +1060,7 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_specialist_pools_flattened(): @@ -949,23 +1070,21 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.list_specialist_pools( - parent='parent_value', - ) + client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_specialist_pools_flattened_error(): @@ -977,8 +1096,7 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) @@ -990,24 +1108,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools( - parent='parent_value', - ) + response = await client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -1020,20 +1138,17 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1042,17 +1157,14 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1065,9 +1177,7 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_specialist_pools(request={}) @@ -1075,18 +1185,16 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in results) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) + def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1095,17 +1203,14 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1116,9 +1221,10 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1127,8 +1233,10 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1137,17 +1245,14 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1158,14 +1263,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in responses) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) + @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1175,8 +1280,10 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1185,17 +1292,14 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1208,14 +1312,16 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +def test_delete_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1224,10 +1330,10 @@ def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_specialist_pool(request) @@ -1246,10 +1352,9 @@ def test_delete_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio'): +async def test_delete_specialist_pool_async(transport: str = "grpc_asyncio"): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1258,11 +1363,11 @@ async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_specialist_pool(request) @@ -1285,13 +1390,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_specialist_pool(request) @@ -1302,10 +1407,7 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1317,13 +1419,15 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_specialist_pool(request) @@ -1334,10 +1438,7 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. 
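# Sketch of the long-running-operation surface being tested: the stubbed
# operations_pb2.Operation above stands in for the real LRO, which the
# client wraps in a google.api_core.operation.Operation future. The
# resource name below is hypothetical.
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
    SpecialistPoolServiceClient,
)

def delete_pool(client: SpecialistPoolServiceClient) -> None:
    lro = client.delete_specialist_pool(
        name="projects/p/locations/l/specialistPools/sp"  # hypothetical
    )
    lro.result()  # blocks until the server-side delete finishes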
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_specialist_pool_flattened(): @@ -1347,23 +1448,21 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool( - name='name_value', - ) + client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_specialist_pool_flattened_error(): @@ -1375,8 +1474,7 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) @@ -1388,26 +1486,24 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool( - name='name_value', - ) + response = await client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -1420,15 +1516,16 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. 
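# Sketch of how the x-goog-request-params header asserted above is built:
# gapic_v1.routing_header.to_grpc_metadata folds the (field, value) pairs
# into a single metadata tuple. The value below is illustrative.
from google.api_core import gapic_v1

metadata = gapic_v1.routing_header.to_grpc_metadata((("name", "name/value"),))
print(metadata)  # ("x-goog-request-params", <urlencoded field assignments>)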
with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) -def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +def test_update_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1437,10 +1534,10 @@ def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_specialist_pool(request) @@ -1459,10 +1556,9 @@ def test_update_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio'): +async def test_update_specialist_pool_async(transport: str = "grpc_asyncio"): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1471,11 +1567,11 @@ async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_specialist_pool(request) @@ -1498,13 +1594,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_specialist_pool(request) @@ -1516,9 +1612,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. 
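# Sketch of the mutual-exclusion rule the *_flattened_error tests assert:
# each method accepts either a full request object or flattened fields,
# never both. The calls are shown commented out because they need a live
# client; the resource name is hypothetical.
from google.cloud.aiplatform_v1beta1.types import specialist_pool_service

request = specialist_pool_service.DeleteSpecialistPoolRequest(
    name="projects/p/locations/l/specialistPools/sp"
)
# client.delete_specialist_pool(request=request)        # OK: request only
# client.delete_specialist_pool(name=request.name)      # OK: flattened only
# client.delete_specialist_pool(request, name="x")      # ValueError: both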
_, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1530,13 +1626,15 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_specialist_pool(request) @@ -1548,9 +1646,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] def test_update_specialist_pool_flattened(): @@ -1560,16 +1658,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
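# Sketch of the update call shape these tests drive: a SpecialistPool
# message plus a FieldMask naming the fields to overwrite. Field values
# are illustrative.
from google.protobuf import field_mask_pb2
from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool

pool = gca_specialist_pool.SpecialistPool(
    name="projects/p/locations/l/specialistPools/sp",  # hypothetical
    display_name="relabeled pool",
)
mask = field_mask_pb2.FieldMask(paths=["display_name"])
# client.update_specialist_pool(specialist_pool=pool, update_mask=mask)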
client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1577,9 +1675,11 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_specialist_pool_flattened_error(): @@ -1592,8 +1692,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1605,19 +1705,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
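# Minimal asyncio sketch of the async flattened-update path tested here,
# assuming Application Default Credentials are available; names are
# illustrative, and the LRO's result() coroutine is awaited.
import asyncio

from google.protobuf import field_mask_pb2
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
    SpecialistPoolServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool

async def main() -> None:
    client = SpecialistPoolServiceAsyncClient()
    lro = await client.update_specialist_pool(
        specialist_pool=gca_specialist_pool.SpecialistPool(
            name="projects/p/locations/l/specialistPools/sp"  # hypothetical
        ),
        update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
    )
    await lro.result()

asyncio.run(main())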
response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1625,9 +1725,11 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -1641,8 +1743,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1653,8 +1755,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1673,8 +1774,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1702,13 +1802,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1719,10 +1822,7 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance( - client.transport, - transports.SpecialistPoolServiceGrpcTransport, - ) + assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) def test_specialist_pool_service_base_transport_error(): @@ -1730,13 +1830,15 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. 
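# Sketch of the invariant test_credentials_transport_error pins down: a
# ready-made transport already carries its own credentials, so passing
# credentials (or a credentials file, or scopes) alongside it is rejected.
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports

transport = transports.SpecialistPoolServiceGrpcTransport(
    credentials=credentials.AnonymousCredentials(),
)
# SpecialistPoolServiceClient(
#     credentials=credentials.AnonymousCredentials(), transport=transport,
# )  # raises ValueError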
- with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1745,12 +1847,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_specialist_pool', - 'get_specialist_pool', - 'list_specialist_pools', - 'delete_specialist_pool', - 'update_specialist_pool', - ) + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1763,23 +1865,28 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -1788,11 +1895,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. 
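# Sketch of the ADC behavior asserted above: when neither credentials nor a
# credentials file is supplied, google.auth.default() is consulted with the
# cloud-platform scope. Runnable only where ADC is configured.
import google.auth

creds, project = google.auth.default(
    scopes=("https://www.googleapis.com/auth/cloud-platform",)
)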
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1800,60 +1907,75 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
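# Sketch of the endpoint override being tested: ClientOptions.api_endpoint
# replaces the default host, and ":443" is appended only when the override
# carries no explicit port.
from google.api_core import client_options
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
    SpecialistPoolServiceClient,
)

client = SpecialistPoolServiceClient(
    credentials=credentials.AnonymousCredentials(),
    client_options=client_options.ClientOptions(
        api_endpoint="aiplatform.googleapis.com:8000"
    ),
)
assert client.transport._host == "aiplatform.googleapis.com:8000"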
transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1862,7 +1984,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1878,26 +2000,30 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) -def test_specialist_pool_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1914,9 +2040,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1925,16 +2049,12 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + 
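# Sketch of the mTLS plumbing the parametrized tests cover: a
# client_cert_source callback yields a client PEM pair, which the transport
# turns into channel-level SSL credentials. The callback below is
# hypothetical, and only the default (system trust roots) call is executed
# so the sketch stays runnable without real certificates.
import grpc

def client_cert_source():
    # hypothetical PEM pair; a real callback returns (cert_bytes, key_bytes)
    return b"...", b"..."

ssl_creds = grpc.ssl_channel_credentials()  # system roots only
# With a real pair: grpc.ssl_channel_credentials(
#     certificate_chain=cert_bytes, private_key=key_bytes)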
credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1942,36 +2062,36 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_specialist_pool_path(): project = "squid" location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) - actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + actual = SpecialistPoolServiceClient.specialist_pool_path( + project, location, specialist_pool + ) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", - + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -1979,18 +2099,20 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -1998,18 +2120,18 @@ def test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2017,18 +2139,18 @@ def test_parse_common_folder_path(): actual = 
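# Sketch of the resource-path helpers round-tripped above: build a full
# resource name from components, then parse it back into a dict.
from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
    SpecialistPoolServiceClient,
)

path = SpecialistPoolServiceClient.specialist_pool_path("p", "l", "sp")
assert path == "projects/p/locations/l/specialistPools/sp"
assert SpecialistPoolServiceClient.parse_specialist_pool_path(path) == {
    "project": "p",
    "location": "l",
    "specialist_pool": "sp",
}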
SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2036,18 +2158,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = SpecialistPoolServiceClient.common_project_path(**expected) @@ -2055,20 +2177,22 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2080,17 +2204,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) From 799d865156d413b48592428c289ae6b95e4e0776 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 27 Oct 2020 12:43:47 -0700 Subject: [PATCH 03/12] update beta --- CODE_OF_CONDUCT.md | 123 +- docs/aiplatform_v1beta1/types.rst | 1 + google/cloud/aiplatform_v1beta1/__init__.py | 10 + .../services/dataset_service/async_client.py | 35 +- .../dataset_service/transports/grpc.py | 4 + .../transports/grpc_asyncio.py | 4 + .../services/endpoint_service/async_client.py | 31 +- 
.../services/endpoint_service/client.py | 4 +- .../endpoint_service/transports/grpc.py | 4 + .../transports/grpc_asyncio.py | 4 + .../services/job_service/async_client.py | 60 +- .../services/job_service/transports/grpc.py | 4 + .../job_service/transports/grpc_asyncio.py | 4 + .../migration_service/async_client.py | 11 +- .../migration_service/transports/grpc.py | 4 + .../transports/grpc_asyncio.py | 4 + .../services/model_service/async_client.py | 30 +- .../services/model_service/transports/grpc.py | 4 + .../model_service/transports/grpc_asyncio.py | 4 + .../services/pipeline_service/async_client.py | 15 +- .../pipeline_service/transports/grpc.py | 4 + .../transports/grpc_asyncio.py | 4 + .../prediction_service/async_client.py | 18 +- .../prediction_service/transports/grpc.py | 4 + .../transports/grpc_asyncio.py | 4 + .../specialist_pool_service/async_client.py | 15 +- .../transports/grpc.py | 4 + .../transports/grpc_asyncio.py | 4 + .../aiplatform_v1beta1/types/__init__.py | 10 + .../types/batch_prediction_job.py | 23 +- .../aiplatform_v1beta1/types/custom_job.py | 28 + .../types/data_labeling_job.py | 7 + .../cloud/aiplatform_v1beta1/types/dataset.py | 13 +- .../aiplatform_v1beta1/types/endpoint.py | 15 +- .../aiplatform_v1beta1/types/explanation.py | 296 +- .../types/explanation_metadata.py | 284 +- .../types/machine_resources.py | 24 +- .../aiplatform_v1beta1/types/operation.py | 6 +- .../types/prediction_service.py | 6 + .../cloud/aiplatform_v1beta1/types/study.py | 97 + .../types/training_pipeline.py | 40 +- scripts/fixup_aiplatform_v1beta1_keywords.py | 1 + synth.metadata | 4 +- .../test_dataset_service.py | 2246 +++++----- .../test_endpoint_service.py | 1592 +++---- .../aiplatform_v1beta1/test_job_service.py | 3640 ++++++++++------- .../test_migration_service.py | 923 ++--- .../aiplatform_v1beta1/test_model_service.py | 2341 ++++++----- .../test_pipeline_service.py | 1266 +++--- .../test_specialist_pool_service.py | 1132 +++-- 50 files changed, 8145 insertions(+), 6266 deletions(-) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index b3d1f60298..039f436812 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,44 +1,95 @@ -# Contributor Code of Conduct +# Code of Conduct -As contributors and maintainers of this project, -and in the interest of fostering an open and welcoming community, -we pledge to respect all people who contribute through reporting issues, -posting feature requests, updating documentation, -submitting pull requests or patches, and other activities. +## Our Pledge -We are committed to making participation in this project -a harassment-free experience for everyone, -regardless of level of experience, gender, gender identity and expression, -sexual orientation, disability, personal appearance, -body size, race, ethnicity, age, religion, or nationality. +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of +experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery -* Personal attacks -* Trolling or insulting/derogatory comments -* Public or private harassment -* Publishing other's private information, -such as physical or electronic -addresses, without explicit permission -* Other unethical or unprofessional conduct. +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions -that are not aligned to this Code of Conduct. -By adopting this Code of Conduct, -project maintainers commit themselves to fairly and consistently -applying these principles to every aspect of managing this project. -Project maintainers who do not follow or enforce the Code of Conduct -may be permanently removed from the project team. - -This code of conduct applies both within project spaces and in public spaces -when an individual is representing the project or its community. - -Instances of abusive, harassing, or otherwise unacceptable behavior -may be reported by opening an issue -or contacting one or more of the project maintainers. - -This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0, -available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/) +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, or to ban temporarily or permanently any +contributor for other behaviors that they deem inappropriate, threatening, +offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when the Project +Steward has a reasonable belief that an individual's behavior may have a +negative impact on the project or its community. + +## Conflict Resolution + +We do not believe that all conflict is bad; healthy debate and disagreement +often yield positive results. 
However, it is never okay to be disrespectful or +to engage in behavior that violates the project’s code of conduct. + +If you see someone violating the code of conduct, you are encouraged to address +the behavior directly with those involved. Many issues can be resolved quickly +and easily, and this gives people more control over the outcome of their +dispute. If you are unable to resolve the matter for any reason, or if the +behavior is threatening or harassing, report it. We are dedicated to providing +an environment where participants feel welcome and safe. + + +Reports should be directed to *googleapis-stewards@google.com*, the +Project Steward(s) for *Google Cloud Client Libraries*. It is the Project Steward’s duty to +receive and address reported violations of the code of conduct. They will then +work with a committee consisting of representatives from the Open Source +Programs Office and the Google Open Source Strategy team. If for any reason you +are uncomfortable reaching out to the Project Steward, please email +opensource@google.com. + +We will investigate every complaint, but you may not receive a direct response. +We will use our discretion in determining when and how to follow up on reported +incidents, which may range from not taking action to permanent expulsion from +the project and project-sponsored spaces. We will notify the accused of the +report and provide them an opportunity to discuss it before any action is taken. +The identity of the reporter will be omitted from the details of the report +supplied to the accused. In potentially harmful situations, such as ongoing +harassment or threats to anyone's safety, we may take action without notice. + +## Attribution + +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, +available at +https://www.contributor-covenant.org/version/1/4/code-of-conduct.html \ No newline at end of file diff --git a/docs/aiplatform_v1beta1/types.rst b/docs/aiplatform_v1beta1/types.rst index 3f8a7c9d65..19bab68ada 100644 --- a/docs/aiplatform_v1beta1/types.rst +++ b/docs/aiplatform_v1beta1/types.rst @@ -3,3 +3,4 @@ Types for Google Cloud Aiplatform v1beta1 API .. 
automodule:: google.cloud.aiplatform_v1beta1.types :members: + :show-inheritance: diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 7d45ebe371..f49f90f5eb 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -81,8 +81,12 @@ from .types.explanation import Explanation from .types.explanation import ExplanationParameters from .types.explanation import ExplanationSpec +from .types.explanation import FeatureNoiseSigma +from .types.explanation import IntegratedGradientsAttribution from .types.explanation import ModelExplanation from .types.explanation import SampledShapleyAttribution +from .types.explanation import SmoothGradConfig +from .types.explanation import XraiAttribution from .types.explanation_metadata import ExplanationMetadata from .types.hyperparameter_tuning_job import HyperparameterTuningJob from .types.io import BigQueryDestination @@ -118,6 +122,7 @@ from .types.machine_resources import AutomaticResources from .types.machine_resources import BatchDedicatedResources from .types.machine_resources import DedicatedResources +from .types.machine_resources import DiskSpec from .types.machine_resources import MachineSpec from .types.machine_resources import ResourcesConsumed from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters @@ -240,6 +245,7 @@ "DeployModelResponse", "DeployedModel", "DeployedModelRef", + "DiskSpec", "Endpoint", "EndpointServiceClient", "EnvVar", @@ -256,6 +262,7 @@ "ExportModelOperationMetadata", "ExportModelRequest", "ExportModelResponse", + "FeatureNoiseSigma", "FilterSplit", "FractionSplit", "GcsDestination", @@ -279,6 +286,7 @@ "ImportDataRequest", "ImportDataResponse", "InputDataConfig", + "IntegratedGradientsAttribution", "JobServiceClient", "JobState", "ListAnnotationsRequest", @@ -335,6 +343,7 @@ "Scheduling", "SearchMigratableResourcesRequest", "SearchMigratableResourcesResponse", + "SmoothGradConfig", "SpecialistPool", "SpecialistPoolServiceClient", "StudySpec", @@ -355,5 +364,6 @@ "UploadModelResponse", "UserActionReference", "WorkerPoolSpec", + "XraiAttribution", "DatasetServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 984683b4ac..775558e3b1 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -207,7 +207,8 @@ async def create_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, dataset]): + has_flattened_params = any([parent, dataset]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -288,7 +289,8 @@ async def get_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
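# Sketch of the guard pattern this commit threads through every async
# method: the any([...]) test is hoisted into a named local before the
# mutual-exclusion check. Behavior is unchanged; the function below is an
# illustrative stand-in.
def create_thing(request=None, *, parent=None, thing=None):
    has_flattened_params = any([parent, thing])
    if request is not None and has_flattened_params:
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )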
@@ -373,7 +375,8 @@ async def update_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([dataset, update_mask]): + has_flattened_params = any([dataset, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -451,7 +454,8 @@ async def list_datasets( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -544,7 +548,8 @@ async def delete_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -635,7 +640,8 @@ async def import_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, import_configs]): + has_flattened_params = any([name, import_configs]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -648,8 +654,9 @@ async def import_data( if name is not None: request.name = name - if import_configs is not None: - request.import_configs = import_configs + + if import_configs: + request.import_configs.extend(import_configs) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -727,7 +734,8 @@ async def export_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, export_config]): + has_flattened_params = any([name, export_config]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -812,7 +820,8 @@ async def list_data_items( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -891,7 +900,8 @@ async def get_annotation_spec( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
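# Sketch of the repeated-field handling import_data switches to: instead of
# assigning the caller's list, the configs are appended onto the request's
# repeated container with extend(). The proto-plus message below is an
# illustrative stand-in, not the real ImportDataRequest.
import proto

class ImportReq(proto.Message):
    import_configs = proto.RepeatedField(proto.STRING, number=1)

req = ImportReq()
req.import_configs.extend(["gs://bucket/import.jsonl"])  # append in place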
- if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -967,7 +977,8 @@ async def list_annotations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 801346ca58..2647c4bd9c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -104,6 +104,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -111,6 +113,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -147,6 +150,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index c0067cb997..1f22b10f3e 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -149,6 +149,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -156,6 +158,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -192,6 +195,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 3801c42a08..9056e7a149 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -198,7 +198,8 @@ async def create_endpoint( # Create or coerce a protobuf request object. 
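# Sketch of the transport change above: the SSL credentials used (or not
# used) to build the gRPC channel are now kept on the instance, so the mTLS
# code paths can inspect them later. The class below is an illustrative
# stand-in for the generated transport.
import grpc

class TransportSketch:
    def __init__(self, ssl_channel_credentials=None, channel=None):
        self._ssl_channel_credentials = ssl_channel_credentials
        if channel is not None:
            # An externally supplied channel owns its own security setup,
            # so no separate SSL credentials are retained.
            self._ssl_channel_credentials = None

t = TransportSketch(ssl_channel_credentials=grpc.ssl_channel_credentials())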
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, endpoint]): + has_flattened_params = any([parent, endpoint]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -280,7 +281,8 @@ async def get_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -355,7 +357,8 @@ async def list_endpoints( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -440,7 +443,8 @@ async def update_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, update_mask]): + has_flattened_params = any([endpoint, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -531,7 +535,8 @@ async def delete_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -649,7 +654,8 @@ async def deploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model, traffic_split]): + has_flattened_params = any([endpoint, deployed_model, traffic_split]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -664,8 +670,9 @@ async def deploy_model( request.endpoint = endpoint if deployed_model is not None: request.deployed_model = deployed_model - if traffic_split is not None: - request.traffic_split = traffic_split + + if traffic_split: + request.traffic_split.update(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -762,7 +769,8 @@ async def undeploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
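# Sketch of the map-field handling deploy_model switches to: the caller's
# dict is merged into the request's protobuf map with update() rather than
# assigned wholesale. The proto-plus message below is an illustrative
# stand-in, not the real DeployModelRequest.
import proto

class DeployReq(proto.Message):
    traffic_split = proto.MapField(proto.STRING, proto.INT32, number=1)

req = DeployReq()
req.traffic_split.update({"deployed-model-id": 100})  # merge in place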
- if request is not None and any([endpoint, deployed_model_id, traffic_split]): + has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -777,8 +785,9 @@ async def undeploy_model( request.endpoint = endpoint if deployed_model_id is not None: request.deployed_model_id = deployed_model_id - if traffic_split is not None: - request.traffic_split = traffic_split + + if traffic_split: + request.traffic_split.update(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index fbbe7219e4..5ea003b827 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -887,7 +887,7 @@ def deploy_model( request.deployed_model = deployed_model if traffic_split: - request.traffic_split = traffic_split + request.traffic_split.update(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -1003,7 +1003,7 @@ def undeploy_model( request.deployed_model_id = deployed_model_id if traffic_split: - request.traffic_split = traffic_split + request.traffic_split.update(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index a7ca0d8b13..70915facf0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -103,6 +103,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -110,6 +112,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -146,6 +149,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index 7d743ebb56..f4e362281b 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -148,6 +148,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -155,6 +157,7 @@ def __init__( # If a channel was explicitly provided, set it. 
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -191,6 +194,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index ca5f400eaa..da6fafd965 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -230,7 +230,8 @@ async def create_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, custom_job]): + has_flattened_params = any([parent, custom_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -309,7 +310,8 @@ async def get_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -384,7 +386,8 @@ async def list_custom_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -477,7 +480,8 @@ async def delete_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -562,7 +566,8 @@ async def cancel_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -640,7 +645,8 @@ async def create_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([parent, data_labeling_job]): + has_flattened_params = any([parent, data_labeling_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -715,7 +721,8 @@ async def get_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -789,7 +796,8 @@ async def list_data_labeling_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -883,7 +891,8 @@ async def delete_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -958,7 +967,8 @@ async def cancel_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1038,7 +1048,8 @@ async def create_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, hyperparameter_tuning_job]): + has_flattened_params = any([parent, hyperparameter_tuning_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1115,7 +1126,8 @@ async def get_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1190,7 +1202,8 @@ async def list_hyperparameter_tuning_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1284,7 +1297,8 @@ async def delete_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1372,7 +1386,8 @@ async def cancel_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1456,7 +1471,8 @@ async def create_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, batch_prediction_job]): + has_flattened_params = any([parent, batch_prediction_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1536,7 +1552,8 @@ async def get_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1611,7 +1628,8 @@ async def list_batch_prediction_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1706,7 +1724,8 @@ async def delete_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -1792,7 +1811,8 @@ async def cancel_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 246b11b5d6..f4b610bd53 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -118,6 +118,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -125,6 +127,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -161,6 +164,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index 428b37f268..83cc826484 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -163,6 +163,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -170,6 +172,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -206,6 +209,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index c9008dc298..af13c4d4fb 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -206,7 +206,8 @@ async def search_migratable_resources( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -302,7 +303,8 @@ async def batch_migrate_resources( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([parent, migrate_resource_requests]): + has_flattened_params = any([parent, migrate_resource_requests]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -315,8 +317,9 @@ async def batch_migrate_resources( if parent is not None: request.parent = parent - if migrate_resource_requests is not None: - request.migrate_resource_requests = migrate_resource_requests + + if migrate_resource_requests: + request.migrate_resource_requests.extend(migrate_resource_requests) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 50d81c4ab3..efd4c4b6a4 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -105,6 +105,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -112,6 +114,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -148,6 +151,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 1450fbf2b5..ba038f57c5 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -150,6 +150,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -157,6 +159,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -193,6 +196,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index aa56f7d953..81c1f9cb51 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -214,7 +214,8 @@ async def upload_model( # Create or coerce a protobuf request object. 
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, model]): + has_flattened_params = any([parent, model]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -293,7 +294,8 @@ async def get_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -368,7 +370,8 @@ async def list_models( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -452,7 +455,8 @@ async def update_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([model, update_mask]): + has_flattened_params = any([model, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -545,7 +549,8 @@ async def delete_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -640,7 +645,8 @@ async def export_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, output_config]): + has_flattened_params = any([name, output_config]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -725,7 +731,8 @@ async def get_model_evaluation( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -800,7 +807,8 @@ async def list_model_evaluations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -881,7 +889,8 @@ async def get_model_evaluation_slice( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -957,7 +966,8 @@ async def list_model_evaluation_slices( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index df720617a7..442b665d3a 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -107,6 +107,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -114,6 +116,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -150,6 +153,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index ffe89774ef..13e9848290 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -152,6 +152,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -159,6 +161,7 @@ def __init__( # If a channel was explicitly provided, set it. 
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -195,6 +198,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 22777c2405..d361b05e21 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -209,7 +209,8 @@ async def create_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, training_pipeline]): + has_flattened_params = any([parent, training_pipeline]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -288,7 +289,8 @@ async def get_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -363,7 +365,8 @@ async def list_training_pipelines( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -457,7 +460,8 @@ async def delete_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -544,7 +548,8 @@ async def cancel_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
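The ``has_flattened_params`` check above enforces a mutual-exclusion contract: callers pass either a fully built request object or the flattened keyword arguments, never both. A minimal sketch of that contract from the caller's side; the project, location, and pipeline IDs are placeholders, and the two successful calls assume default credentials and a reachable API endpoint:

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.PipelineServiceClient()
    name = "projects/my-project/locations/us-central1/trainingPipelines/123"

    # Either pass a fully built request object...
    request = aiplatform_v1beta1.GetTrainingPipelineRequest(name=name)
    pipeline = client.get_training_pipeline(request=request)

    # ...or pass only the flattened field.
    pipeline = client.get_training_pipeline(name=name)

    # Mixing the two fails fast, before any RPC is attempted.
    try:
        client.get_training_pipeline(request=request, name=name)
    except ValueError as exc:
        print(exc)  # If the `request` argument is set, then none of ...
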
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 66580ae42e..4fc6389449 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -108,6 +108,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -115,6 +117,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -151,6 +154,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index a66285f6dc..2e6f51e1a3 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -153,6 +153,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -160,6 +162,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -196,6 +199,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index 606ce0f46b..c82146bafa 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -206,7 +206,8 @@ async def predict( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, instances, parameters]): + has_flattened_params = any([endpoint, instances, parameters]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -219,11 +220,12 @@ async def predict( if endpoint is not None: request.endpoint = endpoint - if instances is not None: - request.instances = instances if parameters is not None: request.parameters = parameters + if instances: + request.instances.extend(instances) + # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
rpc = gapic_v1.method_async.wrap_method( @@ -328,9 +330,8 @@ async def explain( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any( - [endpoint, instances, parameters, deployed_model_id] - ): + has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -343,13 +344,14 @@ async def explain( if endpoint is not None: request.endpoint = endpoint - if instances is not None: - request.instances = instances if parameters is not None: request.parameters = parameters if deployed_model_id is not None: request.deployed_model_id = deployed_model_id + if instances: + request.instances.extend(instances) + # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = gapic_v1.method_async.wrap_method( diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 6c4cdf8d12..1a102e1a61 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -101,6 +101,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -108,6 +110,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -144,6 +147,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index f8d06bc047..a0785007db 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -146,6 +146,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -153,6 +155,7 @@ def __init__( # If a channel was explicitly provided, set it. 
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -189,6 +192,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index 507ce92262..77f40bd4ad 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -215,7 +215,8 @@ async def create_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, specialist_pool]): + has_flattened_params = any([parent, specialist_pool]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -307,7 +308,8 @@ async def get_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -382,7 +384,8 @@ async def list_specialist_pools( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -476,7 +479,8 @@ async def delete_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -572,7 +576,8 @@ async def update_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([specialist_pool, update_mask]): + has_flattened_params = any([specialist_pool, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
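The ``request.instances.extend(...)`` and ``request.traffic_split.update(...)`` rewrites above reflect how proto-plus exposes repeated and map fields: the generated clients now merge the caller's values into the request field in place instead of assigning over it. A minimal sketch of the same pattern outside the client; the endpoint names and the ``"0"`` traffic-split key are placeholders:

    from google.protobuf import struct_pb2

    from google.cloud.aiplatform_v1beta1 import types

    # Repeated field: the caller's sequence is appended onto the proto list in place.
    predict_request = types.PredictRequest(endpoint="projects/p/locations/l/endpoints/e")
    instances = [struct_pb2.Value(string_value="raw-instance")]
    predict_request.instances.extend(instances)

    # Map field: the caller's dict is merged into the proto map in place.
    deploy_request = types.DeployModelRequest(endpoint="projects/p/locations/l/endpoints/e")
    deploy_request.traffic_split.update({"0": 100})

    print(len(predict_request.instances), dict(deploy_request.traffic_split))

Because ``update`` merges rather than replaces, any entries already set on the request are preserved.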
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 18bdaaa035..2d1442ae33 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -109,6 +109,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -116,6 +118,7 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -152,6 +155,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index e2763c647f..7d038edc4f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -154,6 +154,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -161,6 +163,7 @@ def __init__( # If a channel was explicitly provided, set it. 
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn( "api_mtls_endpoint and client_cert_source are deprecated", @@ -197,6 +200,7 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 82fa939f8c..97e5625d20 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -37,6 +37,7 @@ AutomaticResources, BatchDedicatedResources, ResourcesConsumed, + DiskSpec, ) from .deployed_model_ref import DeployedModelRef from .env_var import EnvVar @@ -48,6 +49,10 @@ ExplanationSpec, ExplanationParameters, SampledShapleyAttribution, + IntegratedGradientsAttribution, + XraiAttribution, + SmoothGradConfig, + FeatureNoiseSigma, ) from .model import ( Model, @@ -231,6 +236,7 @@ "AutomaticResources", "BatchDedicatedResources", "ResourcesConsumed", + "DiskSpec", "DeployedModelRef", "EnvVar", "ExplanationMetadata", @@ -240,6 +246,10 @@ "ExplanationSpec", "ExplanationParameters", "SampledShapleyAttribution", + "IntegratedGradientsAttribution", + "XraiAttribution", + "SmoothGradConfig", + "FeatureNoiseSigma", "Model", "PredictSchemata", "ModelContainerSpec", diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 2f464e6c8f..64892b8271 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -98,13 +98,22 @@ class BatchPredictionJob(proto.Message): Generate explanation along with the batch prediction results. - This can only be set to true for AutoML tabular Models, and - only when the output destination is BigQuery. When it's - true, the batch prediction output will include a column - named ``explanation``. The value is a struct that conforms - to the - ``Explanation`` - object. + When it's true, the batch prediction output will change + based on the [output + format][BatchPredictionJob.output_config.predictions_format]: + + - ``bigquery``: output will include a column named + ``explanation``. The value is a struct that conforms to + the + ``Explanation`` + object. + - ``jsonl``: The JSON objects on each line will include an + additional entry keyed ``explanation``. The value of the + entry is a JSON object that conforms to the + ``Explanation`` + object. + - ``csv``: Generating explanations for CSV format is not + supported. output_info (~.batch_prediction_job.BatchPredictionJob.OutputInfo): Output only. Information further describing the output of this job. diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index c8147f9d70..2d8745538c 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -115,6 +115,24 @@ class CustomJobSpec(proto.Message): including machine type and Docker image. scheduling (~.custom_job.Scheduling): Scheduling options for a CustomJob. + service_account (str): + Specifies the service account for workload + run-as account. Users submitting jobs must have + act-as permission on this run-as account. + network (str): + The full name of the Compute Engine + `network `__ + to which the Job should be peered. 
For example, + projects/12345/global/networks/myVPC. + + [Format](https://cloud.google.com/compute/docs/reference/rest/v1/networks/insert) + is of the form projects/{project}/global/networks/{network}. + Where {project} is a project number, as in '12345', and + {network} is network name. + + Private services access must already be configured for the + network. If left unspecified, the job is not peered with any + network. base_output_directory (~.io.GcsDestination): The Google Cloud Storage location to store the output of this CustomJob or HyperparameterTuningJob. For @@ -154,6 +172,10 @@ class CustomJobSpec(proto.Message): scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) + service_account = proto.Field(proto.STRING, number=4) + + network = proto.Field(proto.STRING, number=5) + base_output_directory = proto.Field( proto.MESSAGE, number=6, message=io.GcsDestination, ) @@ -173,6 +195,8 @@ class WorkerPoolSpec(proto.Message): replica_count (int): Required. The number of worker replicas to use for this worker pool. + disk_spec (~.machine_resources.DiskSpec): + Disk spec. """ container_spec = proto.Field( @@ -189,6 +213,10 @@ class WorkerPoolSpec(proto.Message): replica_count = proto.Field(proto.INT64, number=2) + disk_spec = proto.Field( + proto.MESSAGE, number=5, message=machine_resources.DiskSpec, + ) + class ContainerSpec(proto.Message): r"""The spec of a Container. diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index d94efba1b0..af1bcdd871 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform_v1beta1.types import job_state from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore from google.type import money_pb2 as money # type: ignore @@ -98,6 +99,10 @@ class DataLabelingJob(proto.Message): update_time (~.timestamp.Timestamp): Output only. Timestamp when this DataLabelingJob was updated most recently. + error (~.status.Status): + Output only. DataLabelingJob errors. It is only populated + when job's state is ``JOB_STATE_FAILED`` or + ``JOB_STATE_CANCELLED``. labels (Sequence[~.data_labeling_job.DataLabelingJob.LabelsEntry]): The labels with user-defined metadata to organize your DataLabelingJobs. @@ -153,6 +158,8 @@ class DataLabelingJob(proto.Message): update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) + error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) + labels = proto.MapField(proto.STRING, proto.STRING, number=11) specialist_pools = proto.RepeatedField(proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 5840df17f3..76f6462f40 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -150,12 +150,13 @@ class ExportDataConfig(proto.Message): written to. In the given directory a new directory will be created with name: ``export-data--`` - where timestamp is in YYYYMMDDHHMMSS format. All export - output will be written into that directory. Inside that - directory, annotations with the same schema will be grouped - into sub directories which are named with the corresponding - annotations' schema title. 
Inside these sub directories, a - schema.yaml will be created to describe the output format. + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format. All export output will be written into that + directory. Inside that directory, annotations with the same + schema will be grouped into sub directories which are named + with the corresponding annotations' schema title. Inside + these sub directories, a schema.yaml will be created to + describe the output format. annotations_filter (str): A filter on Annotations of the Dataset. Only Annotations on to-be-exported DataItems(specified by [data_items_filter][]) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 07f6a2c61b..f1ba6ed85d 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -145,9 +145,16 @@ class DeployedModel(proto.Message): ``Model.explanation_spec`` must be populated, otherwise explanation for this Model is not allowed. - - Currently, only AutoML tabular Models support - explanation_spec. + service_account (str): + The service account that the DeployedModel's container runs + as. Specify the email address of the service account. If + this service account is not specified, the container runs as + a service account that doesn't have access to the resource + project. + + Users deploying the Model must have the + ``iam.serviceAccounts.actAs`` permission on this service + account. enable_container_logging (bool): If true, the container of the DeployedModel instances will send ``stderr`` and ``stdout`` streams to Stackdriver @@ -192,6 +199,8 @@ class DeployedModel(proto.Message): proto.MESSAGE, number=9, message=explanation.ExplanationSpec, ) + service_account = proto.Field(proto.STRING, number=11) + enable_container_logging = proto.Field(proto.BOOL, number=12) enable_access_logging = proto.Field(proto.BOOL, number=13) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 06b930d90c..7a495fff1e 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -31,18 +31,20 @@ "ExplanationSpec", "ExplanationParameters", "SampledShapleyAttribution", + "IntegratedGradientsAttribution", + "XraiAttribution", + "SmoothGradConfig", + "FeatureNoiseSigma", }, ) class Explanation(proto.Message): r"""Explanation of a prediction (provided in - ``PredictResponse.predictions`` - ) produced by the Model on a given + ``PredictResponse.predictions``) + produced by the Model on a given ``instance``. - Currently, only AutoML tabular Models support explanation. - Attributes: attributions (Sequence[~.explanation.Attribution]): Output only. Feature attributions grouped by predicted @@ -57,6 +59,16 @@ class Explanation(proto.Message): ``Attribution.output_index`` can be used to identify which output this attribution is explaining. + + If users set + ``ExplanationParameters.top_k``, + the attributions are sorted by + ``instance_output_value`` + in descending order. If + ``ExplanationParameters.output_indices`` + is specified, the attributions are stored by + ``Attribution.output_index`` + in the same order as they appear in the output_indices. 
""" attributions = proto.RepeatedField(proto.MESSAGE, number=1, message="Attribution",) @@ -65,8 +77,6 @@ class Explanation(proto.Message): class ModelExplanation(proto.Message): r"""Aggregated explanation metrics for a Model over a set of instances. - Currently, only AutoML tabular Models support aggregated - explanation. Attributes: mean_attributions (Sequence[~.explanation.Attribution]): @@ -115,9 +125,8 @@ class Attribution(proto.Message): The field name of the output is determined by the key in ``ExplanationMetadata.outputs``. - If the Model predicted output is a tensor value (for - example, an ndarray), this is the value in the output - located by + If the Model's predicted output has multiple dimensions + (rank > 1), this is the value in the output located by ``output_index``. If there are multiple baselines, their output values are @@ -128,16 +137,15 @@ class Attribution(proto.Message): name of the output is determined by the key in ``ExplanationMetadata.outputs``. - If the Model predicted output is a tensor value (for - example, an ndarray), this is the value in the output - located by + If the Model predicted output has multiple dimensions, this + is the value in the output located by ``output_index``. feature_attributions (~.struct.Value): Output only. Attributions of each explained feature. Features are extracted from the [prediction instances][google.cloud.aiplatform.v1beta1.ExplainRequest.instances] - according to [explanation input - metadata][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs]. + according to [explanation metadata for + inputs][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs]. The value is a struct, whose keys are the name of the feature. The values are how much the feature in the @@ -175,11 +183,11 @@ class Attribution(proto.Message): output. If the prediction output is a scalar value, output_index is - not populated. If the prediction output is a tensor value - (for example, an ndarray), the length of output_index is the - same as the number of dimensions of the output. The i-th - element in output_index is the element index of the i-th - dimension of the output vector. Indexes start from 0. + not populated. If the prediction output has multiple + dimensions, the length of the output_index list is the same + as the number of dimensions of the output. The i-th element + in output_index is the element index of the i-th dimension + of the output vector. Indices start from 0. output_display_name (str): Output only. The display name of the output identified by ``output_index``, @@ -196,11 +204,29 @@ class Attribution(proto.Message): caused by approximation used in the explanation method. Lower value means more precise attributions. - For Sampled Shapley - ``attribution``, - increasing - ``path_count`` - might reduce the error. + - For [Sampled Shapley + attribution][ExplanationParameters.sampled_shapley_attribution], + increasing + ``path_count`` + may reduce the error. + - For [Integrated Gradients + attribution][ExplanationParameters.integrated_gradients_attribution], + increasing + ``step_count`` + may reduce the error. + - For [XRAI + attribution][ExplanationParameters.xrai_attribution], + increasing + ``step_count`` + may reduce the error. + + Refer to AI Explanations Whitepaper for more details: + + https://storage.googleapis.com/cloud-ai-whitepapers/AI%20Explainability%20Whitepaper.pdf + output_name (str): + Output only. Name of the explain output. Specified as the + key in + ``ExplanationMetadata.outputs``. 
""" baseline_output_value = proto.Field(proto.DOUBLE, number=1) @@ -215,10 +241,11 @@ class Attribution(proto.Message): approximation_error = proto.Field(proto.DOUBLE, number=6) + output_name = proto.Field(proto.STRING, number=7) + class ExplanationSpec(proto.Message): r"""Specification of Model explanation. - Currently, only AutoML tabular Models support explanation. Attributes: parameters (~.explanation.ExplanationParameters): @@ -245,13 +272,69 @@ class ExplanationParameters(proto.Message): Shapley values for features that contribute to the label being predicted. A sampling strategy is used to approximate the value rather than - considering all subsets of features. + considering all subsets of features. Refer to + this paper for model details: + https://arxiv.org/abs/1306.4265. + integrated_gradients_attribution (~.explanation.IntegratedGradientsAttribution): + An attribution method that computes Aumann- + hapley values taking advantage of the model's + fully differentiable structure. Refer to this + paper for more details: + https://arxiv.org/abs/1703.01365 + xrai_attribution (~.explanation.XraiAttribution): + An attribution method that redistributes + Integrated Gradients attribution to segmented + regions, taking advantage of the model's fully + differentiable structure. Refer to this paper + for more details: + https://arxiv.org/abs/1906.02825 + XRAI currently performs better on natural + images, like a picture of a house or an animal. + If the images are taken in artificial + environments, like a lab or manufacturing line, + or from diagnostic equipment, like x-rays or + quality-control cameras, use Integrated + Gradients instead. + top_k (int): + If populated, returns attributions for top K + indices of outputs (defaults to 1). Only applies + to Models that predicts more than one outputs + (e,g, multi-class Models). When set to -1, + returns explanations for all outputs. + output_indices (~.struct.ListValue): + If populated, only returns attributions that have + ``output_index`` contained in + output_indices. It must be an ndarray of integers, with the + same shape of the output it's explaining. + + If not populated, returns attributions for + ``top_k`` + indices of outputs. If neither top_k nor output_indeices is + populated, returns the argmax index of the outputs. + + Only applicable to Models that predict multiple outputs + (e,g, multi-class Models that predict multiple classes). """ sampled_shapley_attribution = proto.Field( - proto.MESSAGE, number=1, message="SampledShapleyAttribution", + proto.MESSAGE, number=1, oneof="method", message="SampledShapleyAttribution", ) + integrated_gradients_attribution = proto.Field( + proto.MESSAGE, + number=2, + oneof="method", + message="IntegratedGradientsAttribution", + ) + + xrai_attribution = proto.Field( + proto.MESSAGE, number=3, oneof="method", message="XraiAttribution", + ) + + top_k = proto.Field(proto.INT32, number=4) + + output_indices = proto.Field(proto.MESSAGE, number=5, message=struct.ListValue,) + class SampledShapleyAttribution(proto.Message): r"""An attribution method that approximates Shapley values for @@ -270,4 +353,163 @@ class SampledShapleyAttribution(proto.Message): path_count = proto.Field(proto.INT32, number=1) +class IntegratedGradientsAttribution(proto.Message): + r"""An attribution method that computes the Aumann-Shapley value + taking advantage of the model's fully differentiable structure. 
+    Refer to this paper for more details:
+    https://arxiv.org/abs/1703.01365
+
+    Attributes:
+        step_count (int):
+            Required. The number of steps for approximating the path
+            integral. A good value to start with is 50; gradually
+            increase it until the sum-to-diff property is within the
+            desired error range.
+
+            Valid range of its value is [1, 100], inclusive.
+        smooth_grad_config (~.explanation.SmoothGradConfig):
+            Config for SmoothGrad approximation of
+            gradients.
+            When enabled, the gradients are approximated by
+            averaging the gradients from noisy samples in
+            the vicinity of the inputs. Adding noise can
+            help improve the computed gradients. Refer to
+            this paper for more details:
+            https://arxiv.org/pdf/1706.03825.pdf
+    """
+
+    step_count = proto.Field(proto.INT32, number=1)
+
+    smooth_grad_config = proto.Field(
+        proto.MESSAGE, number=2, message="SmoothGradConfig",
+    )
+
+
+class XraiAttribution(proto.Message):
+    r"""An explanation method that redistributes Integrated Gradients
+    attributions to segmented regions, taking advantage of the model's
+    fully differentiable structure. Refer to this paper for more
+    details: https://arxiv.org/abs/1906.02825
+
+    Only supports image Models (``modality`` is
+    IMAGE).
+
+    Attributes:
+        step_count (int):
+            Required. The number of steps for approximating the path
+            integral. A good value to start with is 50; gradually
+            increase it until the sum-to-diff property is met within
+            the desired error range.
+
+            Valid range of its value is [1, 100], inclusive.
+        smooth_grad_config (~.explanation.SmoothGradConfig):
+            Config for SmoothGrad approximation of
+            gradients.
+            When enabled, the gradients are approximated by
+            averaging the gradients from noisy samples in
+            the vicinity of the inputs. Adding noise can
+            help improve the computed gradients. Refer to
+            this paper for more details:
+            https://arxiv.org/pdf/1706.03825.pdf
+    """
+
+    step_count = proto.Field(proto.INT32, number=1)
+
+    smooth_grad_config = proto.Field(
+        proto.MESSAGE, number=2, message="SmoothGradConfig",
+    )
+
+
+class SmoothGradConfig(proto.Message):
+    r"""Config for SmoothGrad approximation of gradients.
+    When enabled, the gradients are approximated by averaging the
+    gradients from noisy samples in the vicinity of the inputs.
+    Adding noise can help improve the computed gradients. Refer to
+    this paper for more details:
+    https://arxiv.org/pdf/1706.03825.pdf
+
+    Attributes:
+        noise_sigma (float):
+            This is a single float value and will be used to add noise
+            to all the features. Use this field when all features are
+            normalized to have the same distribution: scaled to the
+            range [0, 1] or [-1, 1], or z-scored so that features have
+            zero mean and unit variance. Refer to this doc for more
+            details about normalization:
+
+            https://developers.google.com/machine-learning/data-prep/transform/normalization.
+
+            For best results, the recommended value is about 10% - 20%
+            of the standard deviation of the input feature. Refer to
+            section 3.2 of the SmoothGrad paper:
+            https://arxiv.org/pdf/1706.03825.pdf. Defaults to 0.1.
+
+            If the distribution is different per feature, set
+            ``feature_noise_sigma``
+            instead for each feature.
+        feature_noise_sigma (~.explanation.FeatureNoiseSigma):
+            This is similar to
+            ``noise_sigma``,
+            but provides additional flexibility. A separate noise sigma
+            can be provided for each feature, which is useful if their
+            distributions are different. No noise is added to features
+            that are not set. If this field is unset,
If this field is unset, + ``noise_sigma`` + will be used for all features. + noisy_sample_count (int): + The number of gradient samples to use for approximation. The + higher this number, the more accurate the gradient is, but + the runtime complexity increases by this factor as well. + Valid range of its value is [1, 50]. Defaults to 3. + """ + + noise_sigma = proto.Field(proto.FLOAT, number=1, oneof="GradientNoiseSigma") + + feature_noise_sigma = proto.Field( + proto.MESSAGE, + number=2, + oneof="GradientNoiseSigma", + message="FeatureNoiseSigma", + ) + + noisy_sample_count = proto.Field(proto.INT32, number=3) + + +class FeatureNoiseSigma(proto.Message): + r"""Noise sigma by features. Noise sigma represents the standard + deviation of the Gaussian kernel that will be used to add noise + to interpolated inputs prior to computing gradients. + + Attributes: + noise_sigma (Sequence[~.explanation.FeatureNoiseSigma.NoiseSigmaForFeature]): + Noise sigma per feature. No noise is added to + features that are not set. + """ + + class NoiseSigmaForFeature(proto.Message): + r"""Noise sigma for a single feature. + + Attributes: + name (str): + The name of the input feature for which noise sigma is + provided. The features are defined in [explanation metadata + inputs][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs]. + sigma (float): + This represents the standard deviation of the Gaussian + kernel that will be used to add noise to the feature prior + to computing gradients. Similar to + ``noise_sigma`` + but represents the noise added to the current feature. + Defaults to 0.1. + """ + + name = proto.Field(proto.STRING, number=1) + + sigma = proto.Field(proto.FLOAT, number=2) + + noise_sigma = proto.RepeatedField( + proto.MESSAGE, number=1, message=NoiseSigmaForFeature, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index cc60c125be..38520669ef 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -40,12 +40,24 @@ class ExplanationMetadata(proto.Message): which has the name specified as the key in ``ExplanationMetadata.inputs``. The baseline of the empty feature is chosen by AI Platform. + + For AI Platform provided Tensorflow images, the key can be + any friendly name of the feature. Once specified, [ + featureAttributions][Attribution.feature_attributions] will + be keyed by this key (if not grouped with another feature). + + For custom images, the key must match the key in + ``instance``[]. outputs (Sequence[~.explanation_metadata.ExplanationMetadata.OutputsEntry]): Required. Map from output names to output metadata. - Keys are the name of the output field in the - prediction to be explained. Currently only one - key is allowed. + For AI Platform provided Tensorflow images, keys + can be any string the user defines. + + For custom images, keys are the name of the + output field in the prediction to be explained. + + Currently only one key is allowed. feature_attributions_schema_uri (str): Points to a YAML file stored on Google Cloud Storage describing the format of the [feature @@ -62,6 +74,11 @@ class ExplanationMetadata(proto.Message): class InputMetadata(proto.Message): r"""Metadata of the input of a feature.
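+# A sketch of the two GradientNoiseSigma oneof variants defined above:
+# a single noise_sigma when all features share one distribution, or
+# per-feature sigmas via FeatureNoiseSigma. The feature name "age" is
+# illustrative, not part of any real schema.
+#
+#     from google.cloud.aiplatform_v1beta1.types import explanation
+#
+#     uniform = explanation.SmoothGradConfig(
+#         noise_sigma=0.1, noisy_sample_count=3,
+#     )
+#     per_feature = explanation.SmoothGradConfig(
+#         feature_noise_sigma=explanation.FeatureNoiseSigma(
+#             noise_sigma=[
+#                 explanation.FeatureNoiseSigma.NoiseSigmaForFeature(
+#                     name="age", sigma=0.2,
+#                 ),
+#             ],
+#         ),
+#     )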
+ Fields other than + ``InputMetadata.input_baselines`` + are applicable only for Models that are using AI Platform-provided + images for Tensorflow. + Attributes: input_baselines (Sequence[~.struct.Value]): Baseline inputs for this feature. @@ -71,20 +88,271 @@ class InputMetadata(proto.Message): specified, AI Platform returns the average attributions across them in [Attributions.baseline_attribution][]. - The element of the baselines must be in the same format as - the feature's input in the + For AI Platform provided Tensorflow images (both 1.x and + 2.x), the shape of each baseline must match the shape of the + input tensor. If a scalar is provided, we broadcast to the + same shape as the input tensor. + + For custom images, the element of the baselines must be in + the same format as the feature's input in the ``instance``[]. The schema of any single instance may be specified via Endpoint's DeployedModels' [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. + input_tensor_name (str): + Name of the input tensor for this feature. + Required and is only applicable to AI Platform + provided images for Tensorflow. + encoding (~.explanation_metadata.ExplanationMetadata.InputMetadata.Encoding): + Defines how the feature is encoded into the + input tensor. Defaults to IDENTITY. + modality (str): + Modality of the feature. Valid values are: + numeric, image. Defaults to numeric. + feature_value_domain (~.explanation_metadata.ExplanationMetadata.InputMetadata.FeatureValueDomain): + The domain details of the input feature + value. Like min/max, original mean or standard + deviation if normalized. + indices_tensor_name (str): + Specifies the index of the values of the input tensor. + Required when the input tensor is a sparse representation. + Refer to Tensorflow documentation for more details: + https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor. + dense_shape_tensor_name (str): + Specifies the shape of the values of the input if the input + is a sparse representation. Refer to Tensorflow + documentation for more details: + https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor. + index_feature_mapping (Sequence[str]): + A list of feature names for each index in the input tensor. + Required when the input + ``InputMetadata.encoding`` + is BAG_OF_FEATURES, BAG_OF_FEATURES_SPARSE, or INDICATOR. + encoded_tensor_name (str): + Encoded tensor is a transformation of the input tensor. Must + be provided if choosing [Integrated Gradients + attribution][ExplanationParameters.integrated_gradients_attribution] + or [XRAI + attribution][google.cloud.aiplatform.v1beta1.ExplanationParameters.xrai_attribution] + and the input tensor is not differentiable. + + An encoded tensor is generated if the input tensor is + encoded by a lookup table. + encoded_baselines (Sequence[~.struct.Value]): + A list of baselines for the encoded tensor. + The shape of each baseline should match the + shape of the encoded tensor. If a scalar is + provided, AI Platform broadcasts to the same + shape as the encoded tensor. + visualization (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization): + Visualization configurations for image + explanation. + group_name (str): + Name of the group that the input belongs to. Features with + the same group name will be treated as one feature when + computing attributions. Features grouped together can have + different shapes in value.
If provided, there will be one + single attribution generated in [ + featureAttributions][Attribution.feature_attributions], + keyed by the group name. """ + class Encoding(proto.Enum): + r"""Defines how the feature is encoded to [encoded_tensor][]. Defaults + to IDENTITY. + """ + ENCODING_UNSPECIFIED = 0 + IDENTITY = 1 + BAG_OF_FEATURES = 2 + BAG_OF_FEATURES_SPARSE = 3 + INDICATOR = 4 + COMBINED_EMBEDDING = 5 + CONCAT_EMBEDDING = 6 + + class FeatureValueDomain(proto.Message): + r"""Domain details of the input feature value. Provides numeric + information about the feature, such as its range (min, max). If the + feature has been pre-processed, for example with z-scoring, then it + provides information about how to recover the original feature. For + example, if the input feature is an image and it has been + pre-processed to obtain 0-mean and stddev = 1 values, then + original_mean and original_stddev refer to the mean and stddev of + the original feature (e.g. image tensor) from which input feature + (with mean = 0 and stddev = 1) was obtained. + + Attributes: + min_ (float): + The minimum permissible value for this + feature. + max_ (float): + The maximum permissible value for this + feature. + original_mean (float): + If this input feature has been normalized to a mean value of + 0, the original_mean specifies the mean value of the domain + prior to normalization. + original_stddev (float): + If this input feature has been normalized to a standard + deviation of 1.0, the original_stddev specifies the standard + deviation of the domain prior to normalization. + """ + + min_ = proto.Field(proto.FLOAT, number=1) + + max_ = proto.Field(proto.FLOAT, number=2) + + original_mean = proto.Field(proto.FLOAT, number=3) + + original_stddev = proto.Field(proto.FLOAT, number=4) + + class Visualization(proto.Message): + r"""Visualization configurations for image explanation. + + Attributes: + type_ (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.Type): + Type of the image visualization. Only applicable to + [Integrated Gradients attribution] + [ExplanationParameters.integrated_gradients_attribution]. + OUTLINES shows regions of attribution, while PIXELS shows + per-pixel attribution. Defaults to OUTLINES. + polarity (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.Polarity): + Whether to only highlight pixels with + positive contributions, negative or both. + Defaults to POSITIVE. + color_map (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.ColorMap): + The color scheme used for the highlighted areas. + + Defaults to PINK_GREEN for [Integrated Gradients + attribution][ExplanationParameters.integrated_gradients_attribution], + which shows positive attributions in green and negative in + pink. + + Defaults to VIRIDIS for [XRAI + attribution][google.cloud.aiplatform.v1beta1.ExplanationParameters.xrai_attribution], + which highlights the most influential regions in yellow and + the least influential in blue. + clip_percent_upperbound (float): + Excludes attributions above the specified percentile from + the highlighted areas. Using the clip_percent_upperbound and + clip_percent_lowerbound together can be useful for filtering + out noise and making it easier to see areas of strong + attribution. Defaults to 99.9. + clip_percent_lowerbound (float): + Excludes attributions below the specified + percentile from the highlighted areas. Defaults + to 35.
+ overlay_type (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.OverlayType): + How the original image is displayed in the + visualization. Adjusting the overlay can help + increase visual clarity if the original image + makes it difficult to view the visualization. + Defaults to NONE. + """ + + class Type(proto.Enum): + r"""Type of the image visualization. Only applicable to [Integrated + Gradients attribution] + [ExplanationParameters.integrated_gradients_attribution]. + """ + TYPE_UNSPECIFIED = 0 + PIXELS = 1 + OUTLINES = 2 + + class Polarity(proto.Enum): + r"""Whether to only highlight pixels with positive contributions, + negative or both. Defaults to POSITIVE. + """ + POLARITY_UNSPECIFIED = 0 + POSITIVE = 1 + NEGATIVE = 2 + BOTH = 3 + + class ColorMap(proto.Enum): + r"""The color scheme used for highlighting areas.""" + COLOR_MAP_UNSPECIFIED = 0 + PINK_GREEN = 1 + VIRIDIS = 2 + RED = 3 + GREEN = 4 + RED_GREEN = 6 + PINK_WHITE_GREEN = 5 + + class OverlayType(proto.Enum): + r"""How the original image is displayed in the visualization.""" + OVERLAY_TYPE_UNSPECIFIED = 0 + NONE = 1 + ORIGINAL = 2 + GRAYSCALE = 3 + MASK_BLACK = 4 + + type_ = proto.Field( + proto.ENUM, + number=1, + enum="ExplanationMetadata.InputMetadata.Visualization.Type", + ) + + polarity = proto.Field( + proto.ENUM, + number=2, + enum="ExplanationMetadata.InputMetadata.Visualization.Polarity", + ) + + color_map = proto.Field( + proto.ENUM, + number=3, + enum="ExplanationMetadata.InputMetadata.Visualization.ColorMap", + ) + + clip_percent_upperbound = proto.Field(proto.FLOAT, number=4) + + clip_percent_lowerbound = proto.Field(proto.FLOAT, number=5) + + overlay_type = proto.Field( + proto.ENUM, + number=6, + enum="ExplanationMetadata.InputMetadata.Visualization.OverlayType", + ) + input_baselines = proto.RepeatedField( proto.MESSAGE, number=1, message=struct.Value, ) + input_tensor_name = proto.Field(proto.STRING, number=2) + + encoding = proto.Field( + proto.ENUM, number=3, enum="ExplanationMetadata.InputMetadata.Encoding", + ) + + modality = proto.Field(proto.STRING, number=4) + + feature_value_domain = proto.Field( + proto.MESSAGE, + number=5, + message="ExplanationMetadata.InputMetadata.FeatureValueDomain", + ) + + indices_tensor_name = proto.Field(proto.STRING, number=6) + + dense_shape_tensor_name = proto.Field(proto.STRING, number=7) + + index_feature_mapping = proto.RepeatedField(proto.STRING, number=8) + + encoded_tensor_name = proto.Field(proto.STRING, number=9) + + encoded_baselines = proto.RepeatedField( + proto.MESSAGE, number=10, message=struct.Value, + ) + + visualization = proto.Field( + proto.MESSAGE, + number=11, + message="ExplanationMetadata.InputMetadata.Visualization", + ) + + group_name = proto.Field(proto.STRING, number=12) + class OutputMetadata(proto.Message): r"""Metadata of the prediction output to be explained. @@ -116,6 +384,10 @@ class OutputMetadata(proto.Message): of the outputs, so that it can be located by ``Attribution.output_index`` for a specific output. + output_tensor_name (str): + Name of the output tensor. Required and is + only applicable to AI Platform provided images + for Tensorflow. 
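+# A sketch of an image-explanation Visualization built from the enums
+# above; unset fields keep their documented defaults (OUTLINES,
+# POSITIVE, the per-method color map, and so on). Values are
+# illustrative.
+#
+#     from google.cloud.aiplatform_v1beta1.types import explanation_metadata
+#
+#     Visualization = (
+#         explanation_metadata.ExplanationMetadata.InputMetadata.Visualization
+#     )
+#     viz = Visualization(
+#         type_=Visualization.Type.OUTLINES,
+#         polarity=Visualization.Polarity.POSITIVE,
+#         clip_percent_upperbound=99.9,  # trim the strongest 0.1%
+#     )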
""" index_display_name_mapping = proto.Field( @@ -126,6 +398,8 @@ class OutputMetadata(proto.Message): proto.STRING, number=2, oneof="display_name_mapping" ) + output_tensor_name = proto.Field(proto.STRING, number=3) + inputs = proto.MapField( proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index f713cd2f64..c71aca024e 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -31,6 +31,7 @@ "AutomaticResources", "BatchDedicatedResources", "ResourcesConsumed", + "DiskSpec", }, ) @@ -130,7 +131,7 @@ class DedicatedResources(proto.Message): as the default value. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) min_replica_count = proto.Field(proto.INT32, number=2) @@ -194,7 +195,7 @@ class BatchDedicatedResources(proto.Message): The default value is 10. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) starting_replica_count = proto.Field(proto.INT32, number=2) @@ -216,4 +217,23 @@ class ResourcesConsumed(proto.Message): replica_hours = proto.Field(proto.DOUBLE, number=1) +class DiskSpec(proto.Message): + r"""Represents the spec of disk options. + + Attributes: + boot_disk_type (str): + Type of the boot disk (default is "pd- + tandard"). Valid values: "pd-ssd" (Persistent + Disk Solid State Drive) or "pd-standard" + (Persistent Disk Hard Disk Drive). + boot_disk_size_gb (int): + Size in GB of the boot disk (default is + 100GB). + """ + + boot_disk_type = proto.Field(proto.STRING, number=1) + + boot_disk_size_gb = proto.Field(proto.INT32, number=2) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 12b2150c35..68fb0daead 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -43,7 +43,9 @@ class GenericOperationMetadata(proto.Message): created. update_time (~.timestamp.Timestamp): Output only. Time when the operation was - updated for the last time. + updated for the last time. If the operation has + finished (successfully or not), this is the + finish time. """ partial_failures = proto.RepeatedField( @@ -64,7 +66,7 @@ class DeleteOperationMetadata(proto.Message): """ generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=GenericOperationMetadata, + proto.MESSAGE, number=1, message="GenericOperationMetadata", ) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index 8f8717d675..b000f88bf8 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -150,6 +150,10 @@ class ExplainResponse(proto.Message): deployed_model_id (str): ID of the Endpoint's DeployedModel that served this explanation. + predictions (Sequence[~.struct.Value]): + The predictions that are the output of the predictions call. + Same as + ``PredictResponse.predictions``. 
""" explanations = proto.RepeatedField( @@ -158,5 +162,7 @@ class ExplainResponse(proto.Message): deployed_model_id = proto.Field(proto.STRING, number=2) + predictions = proto.RepeatedField(proto.MESSAGE, number=3, message=struct.Value,) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 5d053e7162..2d6f4ae8c3 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -159,6 +159,12 @@ class ParameterSpec(proto.Message): scale_type (~.study.StudySpec.ParameterSpec.ScaleType): How the parameter should be scaled. Leave unset for ``CATEGORICAL`` parameters. + conditional_parameter_specs (Sequence[~.study.StudySpec.ParameterSpec.ConditionalParameterSpec]): + A conditional parameter node is active if the parameter's + value matches the conditional node's parent_value_condition. + + If two items in conditional_parameter_specs have the same + name, they must have disjoint parent_value_condition. """ class ScaleType(proto.Enum): @@ -225,6 +231,91 @@ class DiscreteValueSpec(proto.Message): values = proto.RepeatedField(proto.DOUBLE, number=1) + class ConditionalParameterSpec(proto.Message): + r"""Represents a parameter spec with condition from its parent + parameter. + + Attributes: + parent_discrete_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition): + The spec for matching values from a parent parameter of + ``DISCRETE`` type. + parent_int_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition): + The spec for matching values from a parent parameter of + ``INTEGER`` type. + parent_categorical_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition): + The spec for matching values from a parent parameter of + ``CATEGORICAL`` type. + parameter_spec (~.study.StudySpec.ParameterSpec): + Required. The spec for a conditional + parameter. + """ + + class DiscreteValueCondition(proto.Message): + r"""Represents the spec to match discrete values from parent + parameter. + + Attributes: + values (Sequence[float]): + Required. Matches values of the parent parameter of + 'DISCRETE' type. All values must exist in + ``discrete_value_spec`` of parent parameter. + + The Epsilon of the value matching is 1e-10. + """ + + values = proto.RepeatedField(proto.DOUBLE, number=1) + + class IntValueCondition(proto.Message): + r"""Represents the spec to match integer values from parent + parameter. + + Attributes: + values (Sequence[int]): + Required. Matches values of the parent parameter of + 'INTEGER' type. All values must lie in + ``integer_value_spec`` of parent parameter. + """ + + values = proto.RepeatedField(proto.INT64, number=1) + + class CategoricalValueCondition(proto.Message): + r"""Represents the spec to match categorical values from parent + parameter. + + Attributes: + values (Sequence[str]): + Required. Matches values of the parent parameter of + 'CATEGORICAL' type. All values must exist in + ``categorical_value_spec`` of parent parameter. 
+ """ + + values = proto.RepeatedField(proto.STRING, number=1) + + parent_discrete_values = proto.Field( + proto.MESSAGE, + number=2, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", + ) + + parent_int_values = proto.Field( + proto.MESSAGE, + number=3, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", + ) + + parent_categorical_values = proto.Field( + proto.MESSAGE, + number=4, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", + ) + + parameter_spec = proto.Field( + proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", + ) + double_value_spec = proto.Field( proto.MESSAGE, number=2, @@ -259,6 +350,12 @@ class DiscreteValueSpec(proto.Message): proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", ) + conditional_parameter_specs = proto.RepeatedField( + proto.MESSAGE, + number=10, + message="StudySpec.ParameterSpec.ConditionalParameterSpec", + ) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 86d6168b8e..f1f0debaf9 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -190,9 +190,9 @@ class InputDataConfig(proto.Message): Split based on the timestamp of the input data pieces. gcs_destination (~.io.GcsDestination): - The Google Cloud Storage location where the output is to be - written to. In the given directory a new directory will be - created with name: + The Google Cloud Storage location where the training data is + to be written to. In the given directory a new directory + will be created with name: ``dataset---`` where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. All training input data will be written into that @@ -203,18 +203,40 @@ class InputDataConfig(proto.Message): Google Cloud Storage wildcard format to support sharded data. e.g.: "gs://.../training-*.jsonl" - - AIP_DATA_FORMAT = "jsonl". + - AIP_DATA_FORMAT = "jsonl" for non-tabular data, "csv" for + tabular data - AIP_TRAINING_DATA_URI = - "gcs_destination/dataset---/training-*.jsonl" + "gcs_destination/dataset---/training-*.${AIP_DATA_FORMAT}" - AIP_VALIDATION_DATA_URI = - "gcs_destination/dataset---/validation-*.jsonl" + "gcs_destination/dataset---/validation-*.${AIP_DATA_FORMAT}" - AIP_TEST_DATA_URI = - "gcs_destination/dataset---/test-*.jsonl". + "gcs_destination/dataset---/test-*.${AIP_DATA_FORMAT}". + bigquery_destination (~.io.BigQueryDestination): + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI = + + "bigquery_destination.dataset\_\ **\ .training" + + - AIP_VALIDATION_DATA_URI = + + "bigquery_destination.dataset\_\ **\ .validation" + + - AIP_TEST_DATA_URI = + "bigquery_destination.dataset\_\ **\ .test". dataset_id (str): Required. 
The ID of the Dataset in the same Project and Location which data will be used to train the Model. The @@ -284,6 +306,10 @@ class InputDataConfig(proto.Message): proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, ) + bigquery_destination = proto.Field( + proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, + ) + dataset_id = proto.Field(proto.STRING, number=1) annotations_filter = proto.Field(proto.STRING, number=6) diff --git a/scripts/fixup_aiplatform_v1beta1_keywords.py b/scripts/fixup_aiplatform_v1beta1_keywords.py index 7188a7d5bc..4842dae628 100644 --- a/scripts/fixup_aiplatform_v1beta1_keywords.py +++ b/scripts/fixup_aiplatform_v1beta1_keywords.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/synth.metadata b/synth.metadata index ec41255fd8..866be8e22e 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,14 +4,14 @@ "git": { "name": ".", "remote": "https://github.com/dizcology/python-aiplatform.git", - "sha": "288035dd0612b35204273d09a2b3dbbba9fe5e2c" + "sha": "b428c3bd3c19861cb431595f71aa43123e0dd1af" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "77c5ba85e05950f5b19ce8a553c1c0db2fba9896" + "sha": "f68649c5f26bcff6817c6d21e90dac0fc71fef8e" } } ], diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 8b4313034b..08020beb3c 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.services.dataset_service import transports from google.cloud.aiplatform_v1beta1.types import annotation @@ -66,11 +62,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -81,35 +73,17 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] -) +@pytest.mark.parametrize("client_class", [DatasetServiceClient, DatasetServiceAsyncClient]) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -117,7 +91,7 @@ def test_dataset_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_dataset_service_client_get_transport_class(): @@ -128,44 +102,29 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) -def test_dataset_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) 
+@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) +def test_dataset_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -181,7 +140,7 @@ def test_dataset_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -197,7 +156,7 @@ def test_dataset_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -217,15 +176,13 @@ def test_dataset_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -238,56 +195,26 @@ def test_dataset_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) +@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -310,21 +237,11 @@ def test_dataset_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -334,9 +251,7 @@ def test_dataset_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -351,17 +266,10 @@ def test_dataset_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -376,23 +284,16 @@ def test_dataset_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -405,24 +306,16 @@ def test_dataset_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -437,12 +330,10 @@ def test_dataset_service_client_client_options_credentials_file( def test_dataset_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -455,11 +346,10 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset( - transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest -): +def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -467,9 +357,11 @@ def test_create_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_dataset(request) @@ -488,20 +380,23 @@ def test_create_dataset_from_dict(): @pytest.mark.asyncio -async def test_create_dataset_async(transport: str = "grpc_asyncio"): +async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.CreateDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.CreateDatasetRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_dataset(request) @@ -510,23 +405,32 @@ async def test_create_dataset_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.CreateDatasetRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_create_dataset_async_from_dict(): + await test_create_dataset_async(request_type=dict) + + def test_create_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_dataset(request) @@ -537,23 +441,28 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_dataset(request) @@ -564,21 +473,29 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -586,40 +503,47 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') def test_create_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.create_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -627,30 +551,31 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), + parent='parent_value', + dataset=gca_dataset.Dataset(name='name_value'), ) -def test_get_dataset( - transport: str = "grpc", request_type=dataset_service.GetDatasetRequest -): +def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -658,13 +583,19 @@ def test_get_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + ) response = client.get_dataset(request) @@ -676,15 +607,16 @@ def test_get_dataset( assert args[0] == dataset_service.GetDatasetRequest() # Establish that the response is the type that we expect. + assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_dataset_from_dict(): @@ -692,26 +624,27 @@ def test_get_dataset_from_dict(): @pytest.mark.asyncio -async def test_get_dataset_async(transport: str = "grpc_asyncio"): +async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetDatasetRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) response = await client.get_dataset(request) @@ -719,30 +652,39 @@ async def test_get_dataset_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.GetDatasetRequest() # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' - assert response.display_name == "display_name_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.etag == 'etag_value' - assert response.etag == "etag_value" + +@pytest.mark.asyncio +async def test_get_dataset_async_from_dict(): + await test_get_dataset_async(request_type=dict) def test_get_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -754,20 +696,27 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -779,79 +728,99 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset(name="name_value",) + client.get_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", + dataset_service.GetDatasetRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.get_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset(name="name_value",) + response = await client.get_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
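# These *_flattened_error tests pin down a general rule of the generated
# surface: the request object and the flattened keyword arguments are
# mutually exclusive ways of specifying the same fields. A sketch of the
# contract these tests assert (resource names are placeholder values):
#
#     client.get_dataset(request={'name': 'name_value'})     # ok
#     client.get_dataset(name='name_value')                  # ok
#     client.get_dataset(dataset_service.GetDatasetRequest(),
#                        name='name_value')                  # raises ValueError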
with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", + dataset_service.GetDatasetRequest(), + name='name_value', ) -def test_update_dataset( - transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest -): +def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -859,13 +828,19 @@ def test_update_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + metadata_schema_uri='metadata_schema_uri_value', + + etag='etag_value', + ) response = client.update_dataset(request) @@ -877,15 +852,16 @@ def test_update_dataset( assert args[0] == dataset_service.UpdateDatasetRequest() # Establish that the response is the type that we expect. + assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_dataset_from_dict(): @@ -893,26 +869,27 @@ def test_update_dataset_from_dict(): @pytest.mark.asyncio -async def test_update_dataset_async(transport: str = "grpc_asyncio"): +async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.UpdateDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.UpdateDatasetRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( + name='name_value', + display_name='display_name_value', + metadata_schema_uri='metadata_schema_uri_value', + etag='etag_value', + )) response = await client.update_dataset(request) @@ -920,30 +897,39 @@ async def test_update_dataset_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.UpdateDatasetRequest() # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_update_dataset_async_from_dict(): + await test_update_dataset_async(request_type=dict) def test_update_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" + request.dataset.name = 'dataset.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -955,22 +941,27 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" + request.dataset.name = 'dataset.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -982,24 +973,29 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. 
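# The metadata assertion that follows checks the x-goog-request-params
# routing header, which the generated client derives from request fields via
# gapic_v1.routing_header (the same helper this file calls directly in the
# pager tests below). A small sketch of what that helper produces; treat the
# exact header value as illustrative:
#
#     from google.api_core import gapic_v1
#
#     md = gapic_v1.routing_header.to_grpc_metadata(
#         (('dataset.name', 'dataset.name/value'),))
#     assert md[0] == 'x-goog-request-params'   # md[1] carries 'dataset.name=...'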
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'dataset.name=dataset.name/value', + ) in kw['metadata'] def test_update_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1007,30 +1003,36 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.update_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1038,8 +1040,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1047,30 +1049,31 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].dataset == gca_dataset.Dataset(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + dataset=gca_dataset.Dataset(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_list_datasets( - transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest -): +def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1078,10 +1081,13 @@ def test_list_datasets( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_datasets(request) @@ -1093,9 +1099,10 @@ def test_list_datasets( assert args[0] == dataset_service.ListDatasetsRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_datasets_from_dict(): @@ -1103,23 +1110,24 @@ def test_list_datasets_from_dict(): @pytest.mark.asyncio -async def test_list_datasets_async(transport: str = "grpc_asyncio"): +async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDatasetsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDatasetsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
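# Why every mock in this file patches '__call__' on
# type(client.transport.list_datasets): the transport exposes each RPC as a
# gRPC multicallable instance, and Python resolves instance calls through the
# type, so patching the type's __call__ intercepts the invocation. A
# standalone sketch of the mechanism with a stand-in class (Multicallable is
# hypothetical, not part of the library):
#
#     from unittest import mock
#
#     class Multicallable:
#         def __call__(self, request):
#             raise RuntimeError('would hit the network')
#
#     stub = Multicallable()
#     with mock.patch.object(type(stub), '__call__') as call:
#         call.return_value = 'fake response'
#         assert stub('request') == 'fake response'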
- with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_datasets(request) @@ -1127,24 +1135,33 @@ async def test_list_datasets_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.ListDatasetsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_datasets_async_from_dict(): + await test_list_datasets_async(request_type=dict) def test_list_datasets_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1156,23 +1173,28 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) await client.list_datasets(request) @@ -1183,100 +1205,138 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_datasets_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets(parent="parent_value",) + client.list_datasets( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_datasets_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", + dataset_service.ListDatasetsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets(parent="parent_value",) + response = await client.list_datasets( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", + dataset_service.ListDatasetsRequest(), + parent='parent_value', ) def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_datasets(request={}) @@ -1284,102 +1344,147 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - + assert all(isinstance(i, dataset.Dataset) + for i in results) def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + with mock.patch.object( + type(client.transport.list_datasets), + '__call__') as call: # Set the response to a series of pages. 
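# The four fake pages configured in these pager tests carry 3, 0, 1 and 2
# datasets; iterating the pager flattens them, which is why the test above
# expects exactly 6 results, while the .pages view yields one raw
# ListDatasetsResponse per page and the trailing RuntimeError only guards
# against requesting a page past the final empty token. A sketch of the page
# shape, continuing the mocked set-up above:
#
#     pages = list(client.list_datasets(request={}).pages)
#     assert [len(page.datasets) for page in pages] == [3, 0, 1, 2]
#     assert pages[-1].next_page_token == ''   # empty token ends iteration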
call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[], + next_page_token='def', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', + ), + dataset_service.ListDatasetsResponse( + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_datasets), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) for i in responses) - + assert all(isinstance(i, dataset.Dataset) + for i in responses) @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_datasets), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + dataset.Dataset(), + ], + next_page_token='abc', + ), + dataset_service.ListDatasetsResponse( + datasets=[], + next_page_token='def', ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", + datasets=[ + dataset.Dataset(), + ], + next_page_token='ghi', ), dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], + datasets=[ + dataset.Dataset(), + dataset.Dataset(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_dataset( - transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest -): +def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1387,9 +1492,11 @@ def test_delete_dataset( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_dataset(request) @@ -1408,20 +1515,23 @@ def test_delete_dataset_from_dict(): @pytest.mark.asyncio -async def test_delete_dataset_async(transport: str = "grpc_asyncio"): +async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.DeleteDatasetRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.DeleteDatasetRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_dataset(request) @@ -1430,23 +1540,32 @@ async def test_delete_dataset_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.DeleteDatasetRequest() # Establish that the response is the type that we expect. 
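# delete_dataset is a long-running operation: the transport layer returns a
# raw operations_pb2.Operation, which the client wraps in a google.api_core
# future before handing it back, hence the isinstance check below. A hedged
# sketch of how callers normally consume that future against a live service
# (the timeout value is arbitrary):
#
#     lro = client.delete_dataset(name='name_value')
#     lro.result(timeout=300)   # blocks until the server-side operation resolves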
assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_dataset_async_from_dict(): + await test_delete_dataset_async(request_type=dict) + + def test_delete_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_dataset(request) @@ -1457,23 +1576,28 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_dataset(request) @@ -1484,81 +1608,101 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_dataset(name="name_value",) + client.delete_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", + dataset_service.DeleteDatasetRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_dataset), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset(name="name_value",) + response = await client.delete_dataset( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", + dataset_service.DeleteDatasetRequest(), + name='name_value', ) -def test_import_data( - transport: str = "grpc", request_type=dataset_service.ImportDataRequest -): +def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1566,9 +1710,11 @@ def test_import_data( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.import_data(request) @@ -1587,20 +1733,23 @@ def test_import_data_from_dict(): @pytest.mark.asyncio -async def test_import_data_async(transport: str = "grpc_asyncio"): +async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ImportDataRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ImportDataRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.import_data(request) @@ -1609,23 +1758,32 @@ async def test_import_data_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.ImportDataRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_import_data_async_from_dict(): + await test_import_data_async(request_type=dict) + + def test_import_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.import_data(request) @@ -1636,23 +1794,28 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.import_data), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.import_data(request) @@ -1663,24 +1826,29 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_import_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) # Establish that the underlying call was made with the expected @@ -1688,47 +1856,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] def test_import_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.import_data), "__call__") as call: + with mock.patch.object( + type(client.transport.import_data), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) # Establish that the underlying call was made with the expected @@ -1736,34 +1904,31 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] + assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], + name='name_value', + import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], ) -def test_export_data( - transport: str = "grpc", request_type=dataset_service.ExportDataRequest -): +def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1771,9 +1936,11 @@ def test_export_data( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.export_data(request) @@ -1792,20 +1959,23 @@ def test_export_data_from_dict(): @pytest.mark.asyncio -async def test_export_data_async(transport: str = "grpc_asyncio"): +async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ExportDataRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. 
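# request_type() instantiates whichever class the test was parametrized with;
# the new *_async_from_dict variants pass request_type=dict, which works
# because the generated methods also accept a plain mapping and coerce it
# into the proto message. A sketch of that equivalence (the field value is a
# placeholder):
#
#     request = dataset_service.ExportDataRequest(name='name_value')
#     assert request == dataset_service.ExportDataRequest({'name': 'name_value'})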
- request = dataset_service.ExportDataRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.export_data(request) @@ -1814,23 +1984,32 @@ async def test_export_data_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.ExportDataRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_export_data_async_from_dict(): + await test_export_data_async(request_type=dict) + + def test_export_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.export_data(request) @@ -1841,23 +2020,28 @@ def test_export_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.export_data(request) @@ -1868,26 +2052,29 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_export_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) # Establish that the underlying call was made with the expected @@ -1895,53 +2082,47 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) + assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) def test_export_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_data), "__call__") as call: + with mock.patch.object( + type(client.transport.export_data), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
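# The export_config exercised below is a message with oneof-style
# destinations, of which these tests only cover the GCS branch. A minimal
# sketch of building it on its own, assuming the module aliases match the
# imports this file already uses:
#
#     from google.cloud.aiplatform_v1beta1.types import dataset, io
#
#     config = dataset.ExportDataConfig(
#         gcs_destination=io.GcsDestination(
#             output_uri_prefix='output_uri_prefix_value'))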
response = await client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) # Establish that the underlying call was made with the expected @@ -1949,38 +2130,31 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) + assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), + name='name_value', + export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), ) -def test_list_data_items( - transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest -): +def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1988,10 +2162,13 @@ def test_list_data_items( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_data_items(request) @@ -2003,9 +2180,10 @@ def test_list_data_items( assert args[0] == dataset_service.ListDataItemsRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_data_items_from_dict(): @@ -2013,23 +2191,24 @@ def test_list_data_items_from_dict(): @pytest.mark.asyncio -async def test_list_data_items_async(transport: str = "grpc_asyncio"): +async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDataItemsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDataItemsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_data_items(request) @@ -2037,24 +2216,33 @@ async def test_list_data_items_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.ListDataItemsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_data_items_async_from_dict(): + await test_list_data_items_async(request_type=dict) def test_list_data_items_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2066,23 +2254,28 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) await client.list_data_items(request) @@ -2093,81 +2286,104 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_data_items_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items(parent="parent_value",) + client.list_data_items( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_data_items_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", + dataset_service.ListDataItemsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Designate an appropriate return value for the call. 
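# (Editor's note, not part of the generated diff.) The async flattened tests
# below assign call.return_value twice: a plain response first, then the same
# message wrapped in grpc_helpers_async.FakeUnaryUnaryCall so the mock can be
# awaited. Only the second assignment is ever observed; the first is dead
# code left behind by the generator.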
call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items(parent="parent_value",) + response = await client.list_data_items( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", + dataset_service.ListDataItemsRequest(), + parent='parent_value', ) def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2176,23 +2392,32 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_data_items(request={}) @@ -2200,14 +2425,18 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) - + assert all(isinstance(i, data_item.DataItem) + for i in results) def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + with mock.patch.object( + type(client.transport.list_data_items), + '__call__') as call: # Set the response to a series of pages. 
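# (Editor's note, not part of the generated diff.) A mock's side_effect, when
# given a sequence, yields one element per call. The fixtures below supply
# three pages carrying next_page_token values 'abc', 'def' and 'ghi', a final
# page whose token defaults to '', and then RuntimeError, so any fetch beyond
# the last page fails loudly instead of looping forever.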
call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2216,32 +2445,40 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_data_items), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2250,37 +2487,46 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) for i in responses) - + assert all(isinstance(i, data_item.DataItem) + for i in responses) @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_data_items), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2289,31 +2535,37 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", + data_items=[], + next_page_token='def', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", + data_items=[ + data_item.DataItem(), + ], + next_page_token='ghi', ), dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec( - transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest -): +def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2322,11 +2574,16 @@ def test_get_annotation_spec( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", + name='name_value', + + display_name='display_name_value', + + etag='etag_value', + ) response = client.get_annotation_spec(request) @@ -2338,13 +2595,14 @@ def test_get_annotation_spec( assert args[0] == dataset_service.GetAnnotationSpecRequest() # Establish that the response is the type that we expect. + assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_annotation_spec_from_dict(): @@ -2352,25 +2610,26 @@ def test_get_annotation_spec_from_dict(): @pytest.mark.asyncio -async def test_get_annotation_spec_async(transport: str = "grpc_asyncio"): +async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetAnnotationSpecRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetAnnotationSpecRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( + name='name_value', + display_name='display_name_value', + etag='etag_value', + )) response = await client.get_annotation_spec(request) @@ -2378,30 +2637,37 @@ async def test_get_annotation_spec_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.GetAnnotationSpecRequest() # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' - assert response.display_name == "display_name_value" + assert response.etag == 'etag_value' - assert response.etag == "etag_value" + +@pytest.mark.asyncio +async def test_get_annotation_spec_async_from_dict(): + await test_get_annotation_spec_async(request_type=dict) def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2413,25 +2679,28 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) + type(client.transport.get_annotation_spec), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) await client.get_annotation_spec(request) @@ -2442,85 +2711,99 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec(name="name_value",) + client.get_annotation_spec( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", + dataset_service.GetAnnotationSpecRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), "__call__" - ) as call: + type(client.transport.get_annotation_spec), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec(name="name_value",) + response = await client.get_annotation_spec( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", + dataset_service.GetAnnotationSpecRequest(), + name='name_value', ) -def test_list_annotations( - transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest -): +def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2528,10 +2811,13 @@ def test_list_annotations( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_annotations(request) @@ -2543,9 +2829,10 @@ def test_list_annotations( assert args[0] == dataset_service.ListAnnotationsRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_annotations_from_dict(): @@ -2553,23 +2840,24 @@ def test_list_annotations_from_dict(): @pytest.mark.asyncio -async def test_list_annotations_async(transport: str = "grpc_asyncio"): +async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListAnnotationsRequest): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListAnnotationsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_annotations(request) @@ -2577,24 +2865,33 @@ async def test_list_annotations_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == dataset_service.ListAnnotationsRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_annotations_async_from_dict(): + await test_list_annotations_async(request_type=dict) def test_list_annotations_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -2606,23 +2903,28 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) await client.list_annotations(request) @@ -2633,81 +2935,104 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_annotations_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.list_annotations(parent="parent_value",) + client.list_annotations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_annotations_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", + dataset_service.ListAnnotationsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations(parent="parent_value",) + response = await client.list_annotations( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", + dataset_service.ListAnnotationsRequest(), + parent='parent_value', ) def test_list_annotations_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2716,23 +3041,32 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_annotations(request={}) @@ -2740,14 +3074,18 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) for i in results) - + assert all(isinstance(i, annotation.Annotation) + for i in results) def test_list_annotations_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + with mock.patch.object( + type(client.transport.list_annotations), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2756,32 +3094,40 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_annotations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2790,37 +3136,46 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) for i in responses) - + assert all(isinstance(i, annotation.Annotation) + for i in responses) @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_annotations), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -2829,23 +3184,30 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token="abc", + next_page_token='abc', ), dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", + annotations=[], + next_page_token='def', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", + annotations=[ + annotation.Annotation(), + ], + next_page_token='ghi', ), dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token @@ -2856,7 +3218,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. 
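(Editor's note.) The hunks above all exercise the same generated calling contract: every RPC accepts either a request message or flattened keyword arguments, never both, and the list RPCs return pagers. A minimal sketch of that contract follows; it assumes the same anonymous credentials the tests use, so the calls only complete against a mocked transport like the ones patched in above, not a live endpoint.

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient
from google.cloud.aiplatform_v1beta1.types import dataset_service

client = DatasetServiceClient(credentials=credentials.AnonymousCredentials())

# Flattened form: no positional request, fields passed as keywords.
pager = client.list_data_items(parent='parent_value')

# Request-object form: a populated message and nothing else.
request = dataset_service.ListDataItemsRequest(parent='parent_value')
pager = client.list_data_items(request=request)

# Mixing the two is rejected client-side, before any RPC is attempted,
# which is why the *_flattened_error tests need no transport mock at all.
try:
    client.list_data_items(dataset_service.ListDataItemsRequest(), parent='parent_value')
except ValueError:
    pass

# Iterating the pager walks DataItem messages across pages, while .pages
# yields the raw responses; next_page_token drains to '' on the last page.
for page in pager.pages:
    print(page.next_page_token)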
@@ -2875,7 +3238,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -2903,16 +3267,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2920,8 +3281,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DatasetServiceGrpcTransport, + ) def test_dataset_service_base_transport_error(): @@ -2929,15 +3295,13 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_dataset_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2946,17 +3310,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
methods = ( - "create_dataset", - "get_dataset", - "update_dataset", - "list_datasets", - "delete_dataset", - "import_data", - "export_data", - "list_data_items", - "get_annotation_spec", - "list_annotations", - ) + 'create_dataset', + 'get_dataset', + 'update_dataset', + 'list_datasets', + 'delete_dataset', + 'import_data', + 'export_data', + 'list_data_items', + 'get_annotation_spec', + 'list_annotations', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2969,28 +3333,23 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -2999,11 +3358,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -3011,75 +3370,62 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. 
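# (Editor's note, not part of the generated diff.) The ADC tests in this file
# share one pattern: patch google.auth.default to hand back
# (AnonymousCredentials(), None) so construction needs no real credentials,
# then use adc.assert_called_once_with(...) to pin the exact scopes and
# quota_project_id the client or transport is expected to request.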
- with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) - def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.insecure_channel('http://localhost/') # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.insecure_channel('http://localhost/') # Check that channel is used if provided. 
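# (Editor's note, not part of the generated diff.) When the caller supplies
# its own gRPC channel, the transport holds no TLS material of its own; the
# regenerated tests make that explicit with the new assertions that
# _ssl_channel_credentials stays None in both the sync and asyncio cases.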
transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3088,7 +3434,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3104,30 +3450,27 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +def test_dataset_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3144,7 +3487,9 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3153,12 +3498,16 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + 
credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3166,17 +3515,20 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_annotation_path(): project = "squid" location = "clam" @@ -3184,26 +3536,19 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( - project=project, - location=location, - dataset=dataset, - data_item=data_item, - annotation=annotation, - ) - actual = DatasetServiceClient.annotation_path( - project, location, dataset, data_item, annotation - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) + actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", + } path = DatasetServiceClient.annotation_path(**expected) @@ -3211,31 +3556,24 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual - def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( - project=project, - location=location, - dataset=dataset, - annotation_spec=annotation_spec, - ) - actual = DatasetServiceClient.annotation_spec_path( - project, location, dataset, annotation_spec - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) + actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + 
"annotation_spec": "nudibranch", + } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3243,26 +3581,24 @@ def test_parse_annotation_spec_path(): actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual - def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( - project=project, location=location, dataset=dataset, data_item=data_item, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", + } path = DatasetServiceClient.data_item_path(**expected) @@ -3270,24 +3606,22 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual - def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + } path = DatasetServiceClient.dataset_path(**expected) @@ -3295,20 +3629,18 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", + "billing_account": "nautilus", + } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3316,18 +3648,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", + "folder": "abalone", + } path = DatasetServiceClient.common_folder_path(**expected) @@ -3335,18 +3667,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization,) + expected = 
"organizations/{organization}".format(organization=organization, ) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", + "organization": "clam", + } path = DatasetServiceClient.common_organization_path(**expected) @@ -3354,18 +3686,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", + "project": "octopus", + } path = DatasetServiceClient.common_project_path(**expected) @@ -3373,22 +3705,20 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", + "project": "cuttlefish", + "location": "mussel", + } path = DatasetServiceClient.common_location_path(**expected) @@ -3400,19 +3730,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 45895347ec..29daf6cff8 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceAsyncClient +from 
google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceClient from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type @@ -66,11 +62,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -81,35 +73,17 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] -) +@pytest.mark.parametrize("client_class", [EndpointServiceClient, EndpointServiceAsyncClient]) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -117,7 +91,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_endpoint_service_client_get_transport_class(): @@ -128,44 +102,29 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - 
modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) -def test_endpoint_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) +def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -181,7 +140,7 @@ def test_endpoint_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -197,7 +156,7 @@ def test_endpoint_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -217,15 +176,13 @@ def test_endpoint_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
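# A minimal sketch of the gate the "Unsupported" case below exercises, assuming
# the client validates GOOGLE_API_USE_CLIENT_CERTIFICATE roughly like this
# illustrative helper (hypothetical name; not the library's own code):
import os

def _should_use_client_cert() -> bool:
    # Only the literal strings "true" and "false" are accepted; any other
    # value (e.g. "Unsupported") raises the ValueError the test asserts.
    value = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
    if value not in ("true", "false"):
        raise ValueError(
            "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be"
            " either `true` or `false`"
        )
    return value == "true"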
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -238,66 +195,26 @@ def test_endpoint_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "true", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "false", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) +@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
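# Rough decision table that these parametrized cases pin down (a summary of
# the assertions below, not library code):
#   GOOGLE_API_USE_CLIENT_CERTIFICATE == "true"  and a cert source is usable
#       -> host is DEFAULT_MTLS_ENDPOINT, ssl channel creds are passed through
#   GOOGLE_API_USE_CLIENT_CERTIFICATE == "false" or no cert source available
#       -> host stays DEFAULT_ENDPOINT, ssl channel creds remain None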
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -320,21 +237,11 @@ def test_endpoint_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -344,9 +251,7 @@ def test_endpoint_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -361,17 +266,10 @@ def test_endpoint_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -386,23 +284,16 @@ def test_endpoint_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -415,24 +306,16 @@ def test_endpoint_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
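# The assertion that follows pins down a pass-through contract: the value set
# on ClientOptions (credentials_file="credentials.json") must reach the
# transport constructor unchanged, while the credentials argument itself stays
# None so the transport resolves credentials from the file.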
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -447,12 +330,10 @@ def test_endpoint_service_client_client_options_credentials_file( def test_endpoint_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -465,11 +346,10 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint( - transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest -): +def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -477,9 +357,11 @@ def test_create_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_endpoint(request) @@ -498,20 +380,23 @@ def test_create_endpoint_from_dict(): @pytest.mark.asyncio -async def test_create_endpoint_async(transport: str = "grpc_asyncio"): +async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.CreateEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.CreateEndpointRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_endpoint(request) @@ -520,23 +405,32 @@ async def test_create_endpoint_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.CreateEndpointRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_create_endpoint_async_from_dict(): + await test_create_endpoint_async(request_type=dict) + + def test_create_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_endpoint(request) @@ -547,23 +441,28 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_endpoint(request) @@ -574,21 +473,29 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
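# Why FakeUnaryUnaryCall below: a bare mock return value cannot be awaited, so
# the async tests wrap the fake response in grpc_helpers_async.FakeUnaryUnaryCall,
# which resolves to that response when the client awaits the stub call.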
- with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -596,40 +503,47 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') def test_create_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.create_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -637,30 +551,31 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), + parent='parent_value', + endpoint=gca_endpoint.Endpoint(name='name_value'), ) -def test_get_endpoint( - transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest -): +def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -668,13 +583,19 @@ def test_get_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + ) response = client.get_endpoint(request) @@ -686,15 +607,16 @@ def test_get_endpoint( assert args[0] == endpoint_service.GetEndpointRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_endpoint_from_dict(): @@ -702,26 +624,27 @@ def test_get_endpoint_from_dict(): @pytest.mark.asyncio -async def test_get_endpoint_async(transport: str = "grpc_asyncio"): +async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.GetEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.GetEndpointRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) response = await client.get_endpoint(request) @@ -729,30 +652,39 @@ async def test_get_endpoint_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.GetEndpointRequest() # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' - assert response.display_name == "display_name_value" + assert response.description == 'description_value' - assert response.description == "description_value" + assert response.etag == 'etag_value' - assert response.etag == "etag_value" + +@pytest.mark.asyncio +async def test_get_endpoint_async_from_dict(): + await test_get_endpoint_async(request_type=dict) def test_get_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -764,20 +696,27 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. 
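# kw['metadata'] below is the metadata kwarg captured from the mocked stub
# call; the routing header is a plain (key, value) tuple, so the check is
# simple tuple membership: ('x-goog-request-params', 'name=name/value') must
# be present for the backend to route the request to the named resource.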
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -789,79 +728,99 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint(name="name_value",) + client.get_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", + endpoint_service.GetEndpointRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.get_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
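# The flattened form client.get_endpoint(name='name_value') must surface as a
# populated request message on the wire, which the test verifies by inspecting
# the captured positional argument: args[0].name == 'name_value'.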
- response = await client.get_endpoint(name="name_value",) + response = await client.get_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", + endpoint_service.GetEndpointRequest(), + name='name_value', ) -def test_list_endpoints( - transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest -): +def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -869,10 +828,13 @@ def test_list_endpoints( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_endpoints(request) @@ -884,9 +846,10 @@ def test_list_endpoints( assert args[0] == endpoint_service.ListEndpointsRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_endpoints_from_dict(): @@ -894,23 +857,24 @@ def test_list_endpoints_from_dict(): @pytest.mark.asyncio -async def test_list_endpoints_async(transport: str = "grpc_asyncio"): +async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.ListEndpointsRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.ListEndpointsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_endpoints(request) @@ -918,24 +882,33 @@ async def test_list_endpoints_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.ListEndpointsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_endpoints_async_from_dict(): + await test_list_endpoints_async(request_type=dict) def test_list_endpoints_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -947,23 +920,28 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) await client.list_endpoints(request) @@ -974,81 +952,104 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_endpoints_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints(parent="parent_value",) + client.list_endpoints( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_endpoints_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", + endpoint_service.ListEndpointsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints(parent="parent_value",) + response = await client.list_endpoints( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
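# Flattened keyword arguments such as parent='parent_value' are shorthand for
# building the request message; supplying both a request object and flattened
# fields is ambiguous, so the client refuses with ValueError rather than
# silently preferring one source.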
with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", + endpoint_service.ListEndpointsRequest(), + parent='parent_value', ) def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1057,23 +1058,32 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_endpoints(request={}) @@ -1081,14 +1091,18 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in results) - + assert all(isinstance(i, endpoint.Endpoint) + for i in results) def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + with mock.patch.object( + type(client.transport.list_endpoints), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1097,32 +1111,40 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. 
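# For the async pager the stub is patched with new_callable=mock.AsyncMock so
# each page fetch can be awaited. The fake pages that follow mirror the sync
# test: 3 + 0 + 1 + 2 = 6 endpoints across tokens 'abc' -> 'def' -> 'ghi' -> ''
# (the empty token ends iteration), with a trailing RuntimeError that only
# fires if the pager wrongly requests a fifth page.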
with mock.patch.object( - type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1131,37 +1153,46 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in responses) - + assert all(isinstance(i, endpoint.Endpoint) + for i in responses) @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_endpoints), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1170,31 +1201,37 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token="abc", + next_page_token='abc', ), endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", + endpoints=[], + next_page_token='def', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + endpoints=[ + endpoint.Endpoint(), + ], + next_page_token='ghi', ), endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_update_endpoint( - transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest -): +def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1202,13 +1239,19 @@ def test_update_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + etag='etag_value', + ) response = client.update_endpoint(request) @@ -1220,15 +1263,16 @@ def test_update_endpoint( assert args[0] == endpoint_service.UpdateEndpointRequest() # Establish that the response is the type that we expect. + assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_endpoint_from_dict(): @@ -1236,26 +1280,27 @@ def test_update_endpoint_from_dict(): @pytest.mark.asyncio -async def test_update_endpoint_async(transport: str = "grpc_asyncio"): +async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UpdateEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.UpdateEndpointRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( + name='name_value', + display_name='display_name_value', + description='description_value', + etag='etag_value', + )) response = await client.update_endpoint(request) @@ -1263,30 +1308,39 @@ async def test_update_endpoint_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.UpdateEndpointRequest() # Establish that the response is the type that we expect. 
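# The request assertion just above compares against a freshly constructed
# UpdateEndpointRequest() rather than the local `request` object: with
# request_type=dict the client parses the dict into a new proto message, so
# identity with the original cannot be assumed, only equality with the default.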
assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.etag == "etag_value" + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_update_endpoint_async_from_dict(): + await test_update_endpoint_async(request_type=dict) def test_update_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" + request.endpoint.name = 'endpoint.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1298,25 +1352,28 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" + request.endpoint.name = 'endpoint.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) await client.update_endpoint(request) @@ -1327,24 +1384,29 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] + assert ( + 'x-goog-request-params', + 'endpoint.name=endpoint.name/value', + ) in kw['metadata'] def test_update_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. 
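The field-header tests above pin down that the resource name from the request is mirrored into the x-goog-request-params entry of the call metadata. A small self-contained sketch of that assertion style, using only unittest.mock; the call here simulates the patched stub handle:

from unittest import mock

call = mock.Mock()
# Simulate the client invoking the stub with routing metadata, the way the
# generated clients do for any request field bound into the URI.
call(
    request="fake-request",
    metadata=[("x-goog-request-params", "endpoint.name=endpoint.name/value")],
)

# mock_calls entries unpack to (name, args, kwargs).
_, _, kw = call.mock_calls[0]
assert (
    "x-goog-request-params",
    "endpoint.name=endpoint.name/value",
) in kw["metadata"]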
call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1352,41 +1414,45 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.update_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1394,30 +1460,31 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + endpoint=gca_endpoint.Endpoint(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_delete_endpoint( - transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest -): +def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1425,9 +1492,11 @@ def test_delete_endpoint( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_endpoint(request) @@ -1446,20 +1515,23 @@ def test_delete_endpoint_from_dict(): @pytest.mark.asyncio -async def test_delete_endpoint_async(transport: str = "grpc_asyncio"): +async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeleteEndpointRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeleteEndpointRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_endpoint(request) @@ -1468,23 +1540,32 @@ async def test_delete_endpoint_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.DeleteEndpointRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_endpoint_async_from_dict(): + await test_delete_endpoint_async(request_type=dict) + + def test_delete_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
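The *_flattened_error hunks above assert that a populated request object and flattened keyword arguments are mutually exclusive. A hedged reduction of the guard this implies; update_endpoint_sketch is a hypothetical stand-in, not the generated client's actual body:

import pytest  # already a dependency of this suite


def update_endpoint_sketch(request=None, *, endpoint=None, update_mask=None):
    # Hypothetical guard: flattened fields and a request object must not
    # be combined, which is the ValueError the tests above expect.
    if request is not None and any(
        field is not None for field in (endpoint, update_mask)
    ):
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )
    # ... otherwise build the request and invoke the transport ...


with pytest.raises(ValueError):
    update_endpoint_sketch(object(), endpoint="endpoint_value")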
- with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_endpoint(request) @@ -1495,23 +1576,28 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_endpoint(request) @@ -1522,81 +1608,101 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_endpoint(name="name_value",) + client.delete_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", + endpoint_service.DeleteEndpointRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_endpoint), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint(name="name_value",) + response = await client.delete_endpoint( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", + endpoint_service.DeleteEndpointRequest(), + name='name_value', ) -def test_deploy_model( - transport: str = "grpc", request_type=endpoint_service.DeployModelRequest -): +def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1604,9 +1710,11 @@ def test_deploy_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. 
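For long-running methods such as delete_endpoint, the stub is faked with a bare google.longrunning Operation proto, which the client later wraps in an api-core future (hence the future.Future isinstance checks above). The stub-level half of that pattern in isolation:

from unittest import mock
from google.longrunning import operations_pb2  # same proto used in the hunks

call = mock.Mock(
    return_value=operations_pb2.Operation(name="operations/spam")
)
operation = call(request=mock.sentinel.request)

assert operation.name == "operations/spam"
assert not operation.done  # the fake LRO has not completed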
- call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.deploy_model(request) @@ -1625,20 +1733,23 @@ def test_deploy_model_from_dict(): @pytest.mark.asyncio -async def test_deploy_model_async(transport: str = "grpc_asyncio"): +async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeployModelRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeployModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.deploy_model(request) @@ -1647,23 +1758,32 @@ async def test_deploy_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.DeployModelRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_deploy_model_async_from_dict(): + await test_deploy_model_async(request_type=dict) + + def test_deploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.deploy_model(request) @@ -1674,23 +1794,28 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.deploy_model(request) @@ -1701,29 +1826,30 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] def test_deploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1731,63 +1857,51 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) + assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.deploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1795,45 +1909,34 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) + assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), + traffic_split={'key_value': 541}, ) -def test_undeploy_model( - transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest -): +def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1841,9 +1944,11 @@ def test_undeploy_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.undeploy_model(request) @@ -1862,20 +1967,23 @@ def test_undeploy_model_from_dict(): @pytest.mark.asyncio -async def test_undeploy_model_async(transport: str = "grpc_asyncio"): +async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UndeployModelRequest): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.UndeployModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.undeploy_model(request) @@ -1884,23 +1992,32 @@ async def test_undeploy_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == endpoint_service.UndeployModelRequest() # Establish that the response is the type that we expect. 
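In the async variants the patched stub must hand back an awaitable, which is why every canned response above is wrapped in grpc_helpers_async.FakeUnaryUnaryCall. A minimal sketch, assuming the api-core helper behaves as it is used throughout this suite:

import asyncio
from unittest import mock
from google.api_core import grpc_helpers_async
from google.longrunning import operations_pb2

call = mock.Mock()
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
    operations_pb2.Operation(name="operations/spam")
)


async def invoke():
    # Awaiting the fake call yields the wrapped proto, just as awaiting a
    # real grpc.aio unary-unary call would yield the response.
    return await call(request=None)


response = asyncio.run(invoke())
assert response.name == "operations/spam"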
assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_undeploy_model_async_from_dict(): + await test_undeploy_model_async(request_type=dict) + + def test_undeploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.undeploy_model(request) @@ -1911,23 +2028,28 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" + request.endpoint = 'endpoint/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.undeploy_model(request) @@ -1938,23 +2060,30 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'endpoint=endpoint/value', + ) in kw['metadata'] def test_undeploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -1962,45 +2091,51 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].deployed_model_id == 'deployed_model_id_value' - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + with mock.patch.object( + type(client.transport.undeploy_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) # Establish that the underlying call was made with the expected @@ -2008,25 +2143,27 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" + assert args[0].endpoint == 'endpoint_value' - assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].deployed_model_id == 'deployed_model_id_value' - assert args[0].traffic_split == {"key_value": 541} + assert args[0].traffic_split == {'key_value': 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, + endpoint='endpoint_value', + deployed_model_id='deployed_model_id_value', + traffic_split={'key_value': 541}, ) @@ -2037,7 +2174,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2056,7 +2194,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -2084,16 +2223,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2101,8 +2237,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.EndpointServiceGrpcTransport, + ) def test_endpoint_service_base_transport_error(): @@ -2110,15 +2251,13 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2127,14 +2266,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
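The transport tests above never touch real Application Default Credentials; google.auth.default is patched to hand back anonymous credentials so client and transport construction succeed offline. The core of that pattern on its own:

from unittest import mock
import google.auth
from google.auth import credentials

with mock.patch.object(google.auth, "default") as adc:
    # Anonymous credentials stand in for whatever ADC would resolve to.
    adc.return_value = (credentials.AnonymousCredentials(), None)
    creds, project = google.auth.default()
    adc.assert_called_once()

assert isinstance(creds, credentials.AnonymousCredentials)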
methods = ( - "create_endpoint", - "get_endpoint", - "list_endpoints", - "update_endpoint", - "delete_endpoint", - "deploy_model", - "undeploy_model", - ) + 'create_endpoint', + 'get_endpoint', + 'list_endpoints', + 'update_endpoint', + 'delete_endpoint', + 'deploy_model', + 'undeploy_model', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2147,28 +2286,23 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2177,11 +2311,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -2189,75 +2323,62 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. 
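The base-transport hunk above loops over the service's method names and checks that each raises NotImplementedError until a concrete transport (gRPC or gRPC-asyncio) overrides it. An illustrative reduction of that loop; FakeBaseTransport is hypothetical:

import pytest


class FakeBaseTransport:
    # Hypothetical stand-in: every RPC attribute raises until a concrete
    # transport provides a real implementation.
    def _unimplemented(self, request=None):
        raise NotImplementedError()

    create_endpoint = _unimplemented
    update_endpoint = _unimplemented
    delete_endpoint = _unimplemented


transport = FakeBaseTransport()
for method in ("create_endpoint", "update_endpoint", "delete_endpoint"):
    with pytest.raises(NotImplementedError):
        getattr(transport, method)(request=object())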
- with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) - def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.insecure_channel('http://localhost/') # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.insecure_channel('http://localhost/') # Check that channel is used if provided. 
transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2266,7 +2387,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2282,30 +2403,27 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +def test_endpoint_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2322,7 +2440,9 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -2331,12 +2451,16 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), 
transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2344,34 +2468,36 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = EndpointServiceClient.endpoint_path(**expected) @@ -2379,24 +2505,22 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = EndpointServiceClient.model_path(**expected) @@ -2404,20 +2528,18 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2425,18 +2547,18 @@ def test_parse_common_billing_account_path(): actual = 
EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = EndpointServiceClient.common_folder_path(**expected) @@ -2444,18 +2566,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = EndpointServiceClient.common_organization_path(**expected) @@ -2463,18 +2585,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = EndpointServiceClient.common_project_path(**expected) @@ -2482,22 +2604,20 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = EndpointServiceClient.common_location_path(**expected) @@ -2509,19 +2629,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.EndpointServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.EndpointServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff 
--git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 19a9fe139c..a5543f7767 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -41,20 +41,14 @@ from google.cloud.aiplatform_v1beta1.services.job_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) +from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state @@ -81,11 +75,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -96,30 +86,17 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert ( - JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) + assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi @pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -127,7 +104,7 @@ def test_job_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_job_service_client_get_transport_class(): @@ -138,42 +115,29 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) -def test_job_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) +def test_job_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new 
one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -189,7 +153,7 @@ def test_job_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -205,7 +169,7 @@ def test_job_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -225,15 +189,13 @@ def test_job_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
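# A minimal usage sketch, not from the patch: client_options.api_endpoint
# overrides the default host, which is what the patched-transport assertions
# here check. The regional endpoint below is only an example value.
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient

opts = ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com")
client = JobServiceClient(credentials=AnonymousCredentials(), client_options=opts)
assert client.transport._host == "us-central1-aiplatform.googleapis.com:443"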
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -246,54 +208,26 @@ def test_job_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) +@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
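# Caller-side sketch of what "client_cert_source is provided" means: the
# placeholder bytes stand in for real PEM contents, and the cert is only
# used when GOOGLE_API_USE_CLIENT_CERTIFICATE is "true".
from google.api_core.client_options import ClientOptions

def cert_source():
    return b"cert bytes", b"key bytes"  # stand-ins for a real cert/key pair

mtls_opts = ClientOptions(client_cert_source=cert_source)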
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -316,21 +250,11 @@ def test_job_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -340,9 +264,7 @@ def test_job_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -357,17 +279,10 @@ def test_job_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -382,23 +297,16 @@ def test_job_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -411,24 +319,16 @@ def test_job_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
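# Equivalent caller-side sketch for the two options tested here; the scope
# URL is the standard cloud-platform scope, the file path is a placeholder:
from google.api_core.client_options import ClientOptions

opts = ClientOptions(
    scopes=["https://www.googleapis.com/auth/cloud-platform"],
    credentials_file="credentials.json",  # placeholder path
)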
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -443,11 +343,11 @@ def test_job_service_client_client_options_credentials_file( def test_job_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + client = JobServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -459,11 +359,10 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job( - transport: str = "grpc", request_type=job_service.CreateCustomJobRequest -): +def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -472,13 +371,16 @@ def test_create_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_custom_job(request) @@ -490,11 +392,12 @@ def test_create_custom_job( assert args[0] == job_service.CreateCustomJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -504,27 +407,26 @@ def test_create_custom_job_from_dict(): @pytest.mark.asyncio -async def test_create_custom_job_async(transport: str = "grpc_asyncio"): +async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateCustomJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
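# The mocking recipe every RPC test below shares, condensed into one
# runnable sketch (the resource name is a placeholder): patch the bound
# gRPC stub method on the transport and hand back a canned proto, so no
# network is involved.
from unittest import mock
from google.auth.credentials import AnonymousCredentials
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
from google.cloud.aiplatform_v1beta1.types import job_service

client = JobServiceClient(credentials=AnonymousCredentials())
with mock.patch.object(type(client.transport.create_custom_job), "__call__") as call:
    call.return_value = gca_custom_job.CustomJob(name="projects/p/locations/l/customJobs/1")
    response = client.create_custom_job(job_service.CreateCustomJobRequest())
assert response.name == "projects/p/locations/l/customJobs/1"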
with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_custom_job(request) @@ -532,30 +434,37 @@ async def test_create_custom_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.CreateCustomJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED +@pytest.mark.asyncio +async def test_create_custom_job_async_from_dict(): + await test_create_custom_job_async(request_type=dict) + + def test_create_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -567,25 +476,28 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) + type(client.transport.create_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) await client.create_custom_job(request) @@ -596,24 +508,29 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. 
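# What the field-header assertions check, in isolation: request fields that
# appear in the resource URI are mirrored into x-goog-request-params
# metadata so the backend can route the call.
from google.api_core import gapic_v1

md = gapic_v1.routing_header.to_grpc_metadata((("parent", "projects/p/locations/l"),))
assert md == ("x-goog-request-params", "parent=projects/p/locations/l")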
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -621,43 +538,45 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') def test_create_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), "__call__" - ) as call: + type(client.transport.create_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
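# The flattened calling style in normal use, as a sketch reusing the mocked
# client and imports from the condensed recipe above (otherwise this would
# call the real API); mixing a request object with flattened kwargs raises
# ValueError, per the *_flattened_error tests.
with mock.patch.object(type(client.transport.create_custom_job), "__call__") as call:
    call.return_value = gca_custom_job.CustomJob()
    client.create_custom_job(
        parent="projects/p/locations/l",
        custom_job=gca_custom_job.CustomJob(display_name="demo"),
    )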
response = await client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -665,30 +584,31 @@ async def test_create_custom_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') @pytest.mark.asyncio async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_custom_job( job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), + parent='parent_value', + custom_job=gca_custom_job.CustomJob(name='name_value'), ) -def test_get_custom_job( - transport: str = "grpc", request_type=job_service.GetCustomJobRequest -): +def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -696,12 +616,17 @@ def test_get_custom_job( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_custom_job(request) @@ -713,11 +638,12 @@ def test_get_custom_job( assert args[0] == job_service.GetCustomJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -727,25 +653,26 @@ def test_get_custom_job_from_dict(): @pytest.mark.asyncio -async def test_get_custom_job_async(transport: str = "grpc_asyncio"): +async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetCustomJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( + name='name_value', + display_name='display_name_value', + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_custom_job(request) @@ -753,28 +680,37 @@ async def test_get_custom_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.GetCustomJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED +@pytest.mark.asyncio +async def test_get_custom_job_async_from_dict(): + await test_get_custom_job_async(request_type=dict) + + def test_get_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: call.return_value = custom_job.CustomJob() client.get_custom_job(request) @@ -786,23 +722,28 @@ def test_get_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) await client.get_custom_job(request) @@ -813,81 +754,99 @@ async def test_get_custom_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_custom_job(name="name_value",) + client.get_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", + job_service.GetCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + with mock.patch.object( + type(client.transport.get_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_custom_job(name="name_value",) + response = await client.get_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", + job_service.GetCustomJobRequest(), + name='name_value', ) -def test_list_custom_jobs( - transport: str = "grpc", request_type=job_service.ListCustomJobsRequest -): +def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -895,10 +854,13 @@ def test_list_custom_jobs( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_custom_jobs(request) @@ -910,9 +872,10 @@ def test_list_custom_jobs( assert args[0] == job_service.ListCustomJobsRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_custom_jobs_from_dict(): @@ -920,21 +883,24 @@ def test_list_custom_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_custom_jobs_async(transport: str = "grpc_asyncio"): +async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListCustomJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListCustomJobsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_custom_jobs(request) @@ -942,24 +908,33 @@ async def test_list_custom_jobs_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.ListCustomJobsRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListCustomJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_custom_jobs_async_from_dict(): + await test_list_custom_jobs_async(request_type=dict) def test_list_custom_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: call.return_value = job_service.ListCustomJobsResponse() client.list_custom_jobs(request) @@ -971,23 +946,28 @@ def test_list_custom_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) await client.list_custom_jobs(request) @@ -998,81 +978,104 @@ async def test_list_custom_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_custom_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_custom_jobs(parent="parent_value",) + client.list_custom_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_custom_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", + job_service.ListCustomJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_custom_jobs(parent="parent_value",) + response = await client.list_custom_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", + job_service.ListCustomJobsRequest(), + parent='parent_value', ) def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1081,21 +1084,32 @@ def test_list_custom_jobs_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[], + next_page_token='def', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_custom_jobs(request={}) @@ -1103,14 +1117,18 @@ def test_list_custom_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in results) - + assert all(isinstance(i, custom_job.CustomJob) + for i in results) def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + with mock.patch.object( + type(client.transport.list_custom_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1119,30 +1137,40 @@ def test_list_custom_jobs_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[], + next_page_token='def', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_custom_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
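# What callers get back, sketched against the same four-page fixture these
# tests build: iterating the pager yields every CustomJob across pages,
# fetching the next page on demand, while .pages yields one
# ListCustomJobsResponse per underlying RPC.
#
#     pager = client.list_custom_jobs(request={})
#     jobs = list(pager)        # 6 CustomJob items with the fixture above
#     for page in client.list_custom_jobs(request={}).pages:
#         print(page.next_page_token)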
call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1151,35 +1179,46 @@ async def test_list_custom_jobs_async_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[], + next_page_token='def', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in responses) - + assert all(isinstance(i, custom_job.CustomJob) + for i in responses) @pytest.mark.asyncio async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_custom_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1188,29 +1227,37 @@ async def test_list_custom_jobs_async_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token="abc", + next_page_token='abc', ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + custom_jobs=[], + next_page_token='def', ), job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + custom_jobs=[ + custom_job.CustomJob(), + ], + next_page_token='ghi', + ), + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + ], ), RuntimeError, ) pages = [] async for page_ in (await client.list_custom_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_custom_job( - transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest -): +def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1219,10 +1266,10 @@ def test_delete_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_custom_job(request) @@ -1241,22 +1288,23 @@ def test_delete_custom_job_from_dict(): @pytest.mark.asyncio -async def test_delete_custom_job_async(transport: str = "grpc_asyncio"): +async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteCustomJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_custom_job(request) @@ -1265,25 +1313,32 @@ async def test_delete_custom_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.DeleteCustomJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_custom_job_async_from_dict(): + await test_delete_custom_job_async(request_type=dict) + + def test_delete_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_custom_job(request) @@ -1294,25 +1349,28 @@ def test_delete_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
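# Outside the mocks, delete_custom_job is a long-running operation: the
# client wraps the raw operations_pb2.Operation faked here in a
# google.api_core.operation.Operation future. Sketch with a placeholder
# resource name:
#
#     lro = client.delete_custom_job(name="projects/p/locations/l/customJobs/1")
#     lro.result(timeout=300)  # blocks until the server-side delete finishes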
with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_custom_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_custom_job(request) @@ -1323,85 +1381,101 @@ async def test_delete_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_custom_job(name="name_value",) + client.delete_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", + job_service.DeleteCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), "__call__" - ) as call: + type(client.transport.delete_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_custom_job(name="name_value",) + response = await client.delete_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", + job_service.DeleteCustomJobRequest(), + name='name_value', ) -def test_cancel_custom_job( - transport: str = "grpc", request_type=job_service.CancelCustomJobRequest -): +def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1410,8 +1484,8 @@ def test_cancel_custom_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1432,19 +1506,20 @@ def test_cancel_custom_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_custom_job_async(transport: str = "grpc_asyncio"): +async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelCustomJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelCustomJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1454,24 +1529,31 @@ async def test_cancel_custom_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.CancelCustomJobRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_cancel_custom_job_async_from_dict(): + await test_cancel_custom_job_async(request_type=dict) + + def test_cancel_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
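# The flattened tests above check two behaviours of the convenience
# signatures: keyword arguments are copied onto a freshly built request,
# and mixing a request object with flattened fields raises ValueError.
# A sketch of that guard (cancel_job and FakeRequest are hypothetical):

import pytest


class FakeRequest:
    def __init__(self, name=None):
        self.name = name


def cancel_job(request=None, *, name=None):
    # Mirror the generated guard: a request object and flattened fields
    # are mutually exclusive.
    if request is not None and name is not None:
        raise ValueError("Cannot pass both a request object and "
                         "flattened field arguments.")
    if request is None:
        request = FakeRequest(name=name)
    return request


def test_flattened():
    assert cancel_job(name="name_value").name == "name_value"


def test_flattened_error():
    with pytest.raises(ValueError):
        cancel_job(FakeRequest(), name="name_value")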
with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: call.return_value = None client.cancel_custom_job(request) @@ -1483,22 +1565,27 @@ def test_cancel_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_custom_job(request) @@ -1510,83 +1597,99 @@ async def test_cancel_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_custom_job(name="name_value",) + client.cancel_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", + job_service.CancelCustomJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
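# Throughout the async tests, the patched stub must return something the
# client can await; grpc_helpers_async.FakeUnaryUnaryCall wraps the
# designated response in such an awaitable. A rough sketch of the idea
# without depending on api_core internals (FakeCall is hypothetical):

import asyncio


class FakeCall:
    """Awaitable stand-in for a gRPC unary-unary call object."""

    def __init__(self, response=None):
        self._response = response

    def __await__(self):
        # A generator-based awaitable: the return value becomes the
        # result of the await expression.
        if False:
            yield
        return self._response


async def _demo():
    assert await FakeCall("ok") == "ok"
    assert await FakeCall(None) is None


asyncio.run(_demo())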
with mock.patch.object( - type(client.transport.cancel_custom_job), "__call__" - ) as call: + type(client.transport.cancel_custom_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_custom_job(name="name_value",) + response = await client.cancel_custom_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", + job_service.CancelCustomJobRequest(), + name='name_value', ) -def test_create_data_labeling_job( - transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest -): +def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1595,19 +1698,28 @@ def test_create_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, - specialist_pools=["specialist_pools_value"], + + specialist_pools=['specialist_pools_value'], + ) response = client.create_data_labeling_job(request) @@ -1619,25 +1731,26 @@ def test_create_data_labeling_job( assert args[0] == job_service.CreateDataLabelingJobRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] def test_create_data_labeling_job_from_dict(): @@ -1645,33 +1758,32 @@ def test_create_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_create_data_labeling_job_async(transport: str = "grpc_asyncio"): +async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateDataLabelingJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) response = await client.create_data_labeling_job(request) @@ -1679,42 +1791,49 @@ async def test_create_data_labeling_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.CreateDataLabelingJobRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_async_from_dict(): + await test_create_data_labeling_job_async(request_type=dict) def test_create_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: call.return_value = gca_data_labeling_job.DataLabelingJob() client.create_data_labeling_job(request) @@ -1726,25 +1845,28 @@ def test_create_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) + type(client.transport.create_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) await client.create_data_labeling_job(request) @@ -1755,24 +1877,29 @@ async def test_create_data_labeling_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1780,45 +1907,45 @@ def test_create_data_labeling_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), "__call__" - ) as call: + type(client.transport.create_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -1826,32 +1953,31 @@ async def test_create_data_labeling_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + parent='parent_value', + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), ) -def test_get_data_labeling_job( - transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest -): +def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1860,19 +1986,28 @@ def test_get_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], + name='name_value', + + display_name='display_name_value', + + datasets=['datasets_value'], + labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", + + instruction_uri='instruction_uri_value', + + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, - specialist_pools=["specialist_pools_value"], + + specialist_pools=['specialist_pools_value'], + ) response = client.get_data_labeling_job(request) @@ -1884,25 +2019,26 @@ def test_get_data_labeling_job( assert args[0] == job_service.GetDataLabelingJobRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] def test_get_data_labeling_job_from_dict(): @@ -1910,33 +2046,32 @@ def test_get_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_get_data_labeling_job_async(transport: str = "grpc_asyncio"): +async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetDataLabelingJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( + name='name_value', + display_name='display_name_value', + datasets=['datasets_value'], + labeler_count=1375, + instruction_uri='instruction_uri_value', + inputs_schema_uri='inputs_schema_uri_value', + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=['specialist_pools_value'], + )) response = await client.get_data_labeling_job(request) @@ -1944,42 +2079,49 @@ async def test_get_data_labeling_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.GetDataLabelingJobRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.datasets == ["datasets_value"] + assert response.datasets == ['datasets_value'] assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" + assert response.instruction_uri == 'instruction_uri_value' - assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.inputs_schema_uri == 'inputs_schema_uri_value' assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] + assert response.specialist_pools == ['specialist_pools_value'] + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async_from_dict(): + await test_get_data_labeling_job_async(request_type=dict) def test_get_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -1991,25 +2133,28 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) + type(client.transport.get_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) await client.get_data_labeling_job(request) @@ -2020,85 +2165,99 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job(name="name_value",) + client.get_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", + job_service.GetDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), "__call__" - ) as call: + type(client.transport.get_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job(name="name_value",) + response = await client.get_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", + job_service.GetDataLabelingJobRequest(), + name='name_value', ) -def test_list_data_labeling_jobs( - transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest -): +def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2107,11 +2266,12 @@ def test_list_data_labeling_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_data_labeling_jobs(request) @@ -2123,9 +2283,10 @@ def test_list_data_labeling_jobs( assert args[0] == job_service.ListDataLabelingJobsRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_data_labeling_jobs_from_dict(): @@ -2133,25 +2294,24 @@ def test_list_data_labeling_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async(transport: str = "grpc_asyncio"): +async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListDataLabelingJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListDataLabelingJobsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_data_labeling_jobs(request) @@ -2159,26 +2319,33 @@ async def test_list_data_labeling_jobs_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.ListDataLabelingJobsRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async_from_dict(): + await test_list_data_labeling_jobs_async(request_type=dict) def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2190,25 +2357,28 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse() - ) + type(client.transport.list_data_labeling_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) await client.list_data_labeling_jobs(request) @@ -2219,87 +2389,104 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.list_data_labeling_jobs(parent="parent_value",) + client.list_data_labeling_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", + job_service.ListDataLabelingJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListDataLabelingJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs(parent="parent_value",) + response = await client.list_data_labeling_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", + job_service.ListDataLabelingJobsRequest(), + parent='parent_value', ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2308,14 +2495,17 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2328,7 +2518,9 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_data_labeling_jobs(request={}) @@ -2336,16 +2528,18 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) - + assert all(isinstance(i, data_labeling_job.DataLabelingJob) + for i in results) def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), "__call__" - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2354,14 +2548,17 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2372,20 +2569,19 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
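# The pager tests above queue four pages through side_effect and then
# confirm that iterating the pager walks every item and every
# next_page_token. The mechanics, with a hand-rolled pager (FakePage and
# iterate_items are hypothetical; the real pagers live in the generated
# pagers modules):

from unittest import mock


class FakePage:
    def __init__(self, items, next_page_token):
        self.items = items
        self.next_page_token = next_page_token


def iterate_items(call):
    # Request pages until the token comes back empty, as the generated
    # list pagers do.
    token = None
    while True:
        page = call(page_token=token)
        yield from page.items
        token = page.next_page_token
        if not token:
            return


def test_pager_walks_all_pages():
    # Same 3 / 0 / 1 / 2 page shape as the generated test.
    call = mock.Mock(side_effect=[
        FakePage([1, 2, 3], "abc"),
        FakePage([], "def"),
        FakePage([4], "ghi"),
        FakePage([5, 6], ""),
    ])
    assert list(iterate_items(call)) == [1, 2, 3, 4, 5, 6]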
call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2394,14 +2590,17 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2412,25 +2611,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) - + assert all(isinstance(i, data_labeling_job.DataLabelingJob) + for i in responses) @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_data_labeling_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2439,14 +2638,17 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", + data_labeling_jobs=[], + next_page_token='def', ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + ], + next_page_token='ghi', ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2459,15 +2661,14 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job( - transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest -): +def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2476,10 +2677,10 @@ def test_delete_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. 
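# The async pager tests patch the stub with new_callable=mock.AsyncMock,
# so every page fetch must be awaited. A matching sketch (FakePage again
# hypothetical; pytest-asyncio assumed, as in this suite):

import pytest
from unittest import mock


class FakePage:
    def __init__(self, items, next_page_token):
        self.items = items
        self.next_page_token = next_page_token


@pytest.mark.asyncio
async def test_async_pager_walks_all_pages():
    # AsyncMock returns the next side_effect value each time it is
    # awaited, just like the patched transport stub.
    call = mock.AsyncMock(side_effect=[
        FakePage([1, 2], "abc"),
        FakePage([3], ""),
    ])
    items, token = [], None
    while True:
        page = await call(page_token=token)
        items.extend(page.items)
        token = page.next_page_token
        if not token:
            break
    assert items == [1, 2, 3]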
with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_data_labeling_job(request) @@ -2498,22 +2699,23 @@ def test_delete_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_delete_data_labeling_job_async(transport: str = "grpc_asyncio"): +async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteDataLabelingJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_data_labeling_job(request) @@ -2522,25 +2724,32 @@ async def test_delete_data_labeling_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.DeleteDataLabelingJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async_from_dict(): + await test_delete_data_labeling_job_async(request_type=dict) + + def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_data_labeling_job(request) @@ -2551,25 +2760,28 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_data_labeling_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_data_labeling_job(request) @@ -2580,85 +2792,101 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_data_labeling_job(name="name_value",) + client.delete_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", + job_service.DeleteDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_data_labeling_job), "__call__" - ) as call: + type(client.transport.delete_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job(name="name_value",) + response = await client.delete_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", + job_service.DeleteDataLabelingJobRequest(), + name='name_value', ) -def test_cancel_data_labeling_job( - transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest -): +def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2667,8 +2895,8 @@ def test_cancel_data_labeling_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2689,19 +2917,20 @@ def test_cancel_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async(transport: str = "grpc_asyncio"): +async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelDataLabelingJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelDataLabelingJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -2711,24 +2940,31 @@ async def test_cancel_data_labeling_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.CancelDataLabelingJobRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async_from_dict(): + await test_cancel_data_labeling_job_async(request_type=dict) + + def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -2740,22 +2976,27 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -2767,84 +3008,99 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
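        # Flattened keywords are sugar for building the request message: the
        # client copies each keyword into the matching request field before
        # sending, so (sketch, using names from this file) these two
        # spellings are equivalent:
        #
        #     client.cancel_data_labeling_job(name='name_value')
        #     client.cancel_data_labeling_job(
        #         request=job_service.CancelDataLabelingJobRequest(
        #             name='name_value'))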
- client.cancel_data_labeling_job(name="name_value",) + client.cancel_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", + job_service.CancelDataLabelingJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), "__call__" - ) as call: + type(client.transport.cancel_data_labeling_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job(name="name_value",) + response = await client.cancel_data_labeling_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", + job_service.CancelDataLabelingJobRequest(), + name='name_value', ) -def test_create_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CreateHyperparameterTuningJobRequest, -): +def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2853,16 +3109,22 @@ def test_create_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. 
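        # The regenerated fake populates every scalar field of the response
        # message, so each ``response.<field>`` assertion further down checks
        # a value the mock actually set rather than a proto3 default.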
call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.create_hyperparameter_tuning_job(request) @@ -2874,11 +3136,12 @@ def test_create_hyperparameter_tuning_job( assert args[0] == job_service.CreateHyperparameterTuningJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -2894,30 +3157,29 @@ def test_create_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): +async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateHyperparameterTuningJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.create_hyperparameter_tuning_job(request) @@ -2925,14 +3187,14 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = "grpc_asy assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.CreateHyperparameterTuningJobRequest() # Establish that the response is the type that we expect. 
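    # Responses are proto-plus wrappers, so the assertions below use plain
    # attribute access, and ``==`` compares field-wise (sketch):
    #
    #     job = gca_hyperparameter_tuning_job.HyperparameterTuningJob(
    #         name='name_value')
    #     assert job.name == 'name_value'
    #     assert job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(
    #         name='name_value')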
assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -2943,18 +3205,25 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = "grpc_asy assert response.state == job_state.JobState.JOB_STATE_QUEUED +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async_from_dict(): + await test_create_hyperparameter_tuning_job_async(request_type=dict) + + def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -2966,25 +3235,28 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob() - ) + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) await client.create_hyperparameter_tuning_job(request) @@ -2995,26 +3267,29 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -3022,51 +3297,45 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.create_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_hyperparameter_tuning_job.HyperparameterTuningJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
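        # Mixing a request object with flattened fields is rejected
        # client-side with ``ValueError`` before any RPC is attempted, which
        # is what the ``*_flattened_error`` tests above pin down (sketch of
        # the failing shape):
        #
        #     client.create_hyperparameter_tuning_job(
        #         job_service.CreateHyperparameterTuningJobRequest(),
        #         parent='parent_value',  # ValueError: mutually exclusive
        #     )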
response = await client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -3074,36 +3343,31 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) + assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), + parent='parent_value', + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), ) -def test_get_hyperparameter_tuning_job( - transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest -): +def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3112,16 +3376,22 @@ def test_get_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) response = client.get_hyperparameter_tuning_job(request) @@ -3133,11 +3403,12 @@ def test_get_hyperparameter_tuning_job( assert args[0] == job_service.GetHyperparameterTuningJobRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3153,30 +3424,29 @@ def test_get_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): +async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetHyperparameterTuningJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( + name='name_value', + display_name='display_name_value', + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + )) response = await client.get_hyperparameter_tuning_job(request) @@ -3184,14 +3454,14 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = "grpc_asynci assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.GetHyperparameterTuningJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.max_trial_count == 1609 @@ -3202,18 +3472,25 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = "grpc_asynci assert response.state == job_state.JobState.JOB_STATE_QUEUED +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async_from_dict(): + await test_get_hyperparameter_tuning_job_async(request_type=dict) + + def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
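    # The ``request_type`` parameter added above lets ``*_async_from_dict``
    # re-run the async test with ``request_type=dict``: GAPIC methods accept
    # a plain dict and coerce it into the request message, so both spellings
    # exercise the same path (sketch):
    #
    #     client.get_hyperparameter_tuning_job(
    #         request={'name': 'name_value'})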
with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3225,25 +3502,28 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob() - ) + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) await client.get_hyperparameter_tuning_job(request) @@ -3254,86 +3534,99 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job(name="name_value",) + client.get_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", + job_service.GetHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.get_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - hyperparameter_tuning_job.HyperparameterTuningJob() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job(name="name_value",) + response = await client.get_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", + job_service.GetHyperparameterTuningJobRequest(), + name='name_value', ) -def test_list_hyperparameter_tuning_jobs( - transport: str = "grpc", - request_type=job_service.ListHyperparameterTuningJobsRequest, -): +def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3342,11 +3635,12 @@ def test_list_hyperparameter_tuning_jobs( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3358,9 +3652,10 @@ def test_list_hyperparameter_tuning_jobs( assert args[0] == job_service.ListHyperparameterTuningJobsRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3368,25 +3663,24 @@ def test_list_hyperparameter_tuning_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async(transport: str = "grpc_asyncio"): +async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListHyperparameterTuningJobsRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListHyperparameterTuningJobsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3394,26 +3688,33 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = "grpc_asyn assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.ListHyperparameterTuningJobsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async_from_dict(): + await test_list_hyperparameter_tuning_jobs_async(request_type=dict) def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3425,25 +3726,28 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse() - ) + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) await client.list_hyperparameter_tuning_jobs(request) @@ -3454,87 +3758,104 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_hyperparameter_tuning_jobs(parent="parent_value",) + client.list_hyperparameter_tuning_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", + job_service.ListHyperparameterTuningJobsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListHyperparameterTuningJobsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) + response = await client.list_hyperparameter_tuning_jobs( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", + job_service.ListHyperparameterTuningJobsRequest(), + parent='parent_value', ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Set the response to a series of pages. 
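        # Assigning a sequence to ``side_effect`` makes each successive stub
        # call return the next response, simulating server-side pagination:
        # the pager keeps issuing list calls while ``next_page_token`` is
        # non-empty. Consumption side (sketch):
        #
        #     pager = client.list_hyperparameter_tuning_jobs(request={})
        #     jobs = list(pager)  # walks every page transparently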
call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3543,16 +3864,17 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3565,7 +3887,9 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -3573,19 +3897,18 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results - ) - + assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results) def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3594,16 +3917,17 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3614,20 +3938,19 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3636,16 +3959,17 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3656,28 +3980,25 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses - ) - + assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
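        # The async variants patch with ``new_callable=mock.AsyncMock`` so
        # the stub itself is awaitable, and consumption switches to
        # ``async for`` (sketch):
        #
        #     async_pager = await client.list_hyperparameter_tuning_jobs(
        #         request={})
        #     async for job in async_pager:
        #         ...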
call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3686,16 +4007,17 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="abc", + next_page_token='abc', ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", + hyperparameter_tuning_jobs=[], + next_page_token='def', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token="ghi", + next_page_token='ghi', ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3706,20 +4028,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( - await client.list_hyperparameter_tuning_jobs(request={}) - ).pages: + async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.DeleteHyperparameterTuningJobRequest, -): +def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3728,10 +4046,10 @@ def test_delete_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_hyperparameter_tuning_job(request) @@ -3750,22 +4068,23 @@ def test_delete_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): +async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteHyperparameterTuningJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. 
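        # Delete is a long-running operation: the stub yields a raw
        # ``operations_pb2.Operation`` and the client wraps it in an
        # operation future, hence the ``isinstance(response, future.Future)``
        # assertion below. Typical real-world use with the sync client
        # (sketch):
        #
        #     lro = client.delete_hyperparameter_tuning_job(name='name_value')
        #     lro.result()  # blocks until the server-side delete completes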
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_hyperparameter_tuning_job(request) @@ -3774,25 +4093,32 @@ async def test_delete_hyperparameter_tuning_job_async(transport: str = "grpc_asy assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async_from_dict(): + await test_delete_hyperparameter_tuning_job_async(request_type=dict) + + def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_hyperparameter_tuning_job(request) @@ -3803,25 +4129,28 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_hyperparameter_tuning_job(request) @@ -3832,86 +4161,101 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job(name="name_value",) + client.delete_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + job_service.DeleteHyperparameterTuningJobRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.delete_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job(name="name_value",) + response = await client.delete_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + job_service.DeleteHyperparameterTuningJobRequest(), + name='name_value', ) -def test_cancel_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): +def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3920,8 +4264,8 @@ def test_cancel_hyperparameter_tuning_job( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None @@ -3942,19 +4286,20 @@ def test_cancel_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): +async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelHyperparameterTuningJobRequest): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelHyperparameterTuningJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -3964,24 +4309,31 @@ async def test_cancel_hyperparameter_tuning_job_async(transport: str = "grpc_asy assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() # Establish that the response is the type that we expect. 
assert response is None +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async_from_dict(): + await test_cancel_hyperparameter_tuning_job_async(request_type=dict) + + def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -3993,22 +4345,27 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4020,83 +4377,99 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: + type(client.transport.cancel_hyperparameter_tuning_job), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job(name="name_value",) + client.cancel_hyperparameter_tuning_job( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. 
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_cancel_hyperparameter_tuning_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.cancel_hyperparameter_tuning_job(
-            job_service.CancelHyperparameterTuningJobRequest(), name="name_value",
+            job_service.CancelHyperparameterTuningJobRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_cancel_hyperparameter_tuning_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_hyperparameter_tuning_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_hyperparameter_tuning_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.cancel_hyperparameter_tuning_job(name="name_value",)
+        response = await client.cancel_hyperparameter_tuning_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_cancel_hyperparameter_tuning_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.cancel_hyperparameter_tuning_job(
-            job_service.CancelHyperparameterTuningJobRequest(), name="name_value",
+            job_service.CancelHyperparameterTuningJobRequest(),
+            name='name_value',
         )


-def test_create_batch_prediction_job(
-    transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest
-):
+def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -4105,15 +4478,20 @@ def test_create_batch_prediction_job(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_batch_prediction_job.BatchPredictionJob(
-            name="name_value",
-            display_name="display_name_value",
-            model="model_value",
+            name='name_value',
+
+            display_name='display_name_value',
+
+            model='model_value',
+
             generate_explanation=True,
+
             state=job_state.JobState.JOB_STATE_QUEUED,
+
         )

         response = client.create_batch_prediction_job(request)

@@ -4125,13 +4503,14 @@ def test_create_batch_prediction_job(
     assert args[0] == job_service.CreateBatchPredictionJobRequest()

     # Establish that the response is the type that we expect.
+
     assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'

-    assert response.display_name == "display_name_value"
+    assert response.display_name == 'display_name_value'

-    assert response.model == "model_value"
+    assert response.model == 'model_value'

     assert response.generate_explanation is True

@@ -4143,29 +4522,28 @@ def test_create_batch_prediction_job_from_dict():


 @pytest.mark.asyncio
-async def test_create_batch_prediction_job_async(transport: str = "grpc_asyncio"):
+async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateBatchPredictionJobRequest):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.CreateBatchPredictionJobRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            gca_batch_prediction_job.BatchPredictionJob(
-                name="name_value",
-                display_name="display_name_value",
-                model="model_value",
-                generate_explanation=True,
-                state=job_state.JobState.JOB_STATE_QUEUED,
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob(
+            name='name_value',
+            display_name='display_name_value',
+            model='model_value',
+            generate_explanation=True,
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        ))

         response = await client.create_batch_prediction_job(request)

@@ -4173,34 +4551,41 @@ async def test_create_batch_prediction_job_async(transport: str = "grpc_asyncio"
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]

-    assert args[0] == request
+    assert args[0] == job_service.CreateBatchPredictionJobRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'

-    assert response.display_name == "display_name_value"
+    assert response.display_name == 'display_name_value'

-    assert response.model == "model_value"
+    assert response.model == 'model_value'

     assert response.generate_explanation is True

     assert response.state == job_state.JobState.JOB_STATE_QUEUED


+@pytest.mark.asyncio
+async def test_create_batch_prediction_job_async_from_dict():
+    await test_create_batch_prediction_job_async(request_type=dict)
+
+
 def test_create_batch_prediction_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CreateBatchPredictionJobRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
         call.return_value = gca_batch_prediction_job.BatchPredictionJob()

         client.create_batch_prediction_job(request)

@@ -4212,25 +4597,28 @@ def test_create_batch_prediction_job_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_create_batch_prediction_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CreateBatchPredictionJobRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_batch_prediction_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            gca_batch_prediction_job.BatchPredictionJob()
-        )
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob())

         await client.create_batch_prediction_job(request)

@@ -4241,26 +4629,29 @@ async def test_create_batch_prediction_job_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 def test_create_batch_prediction_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_batch_prediction_job.BatchPredictionJob()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.create_batch_prediction_job(
-            parent="parent_value",
-            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(
-                name="name_value"
-            ),
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
         )

         # Establish that the underlying call was made with the expected
@@ -4268,51 +4659,45 @@ def test_create_batch_prediction_job_flattened():
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'

-        assert args[
-            0
-        ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(
-            name="name_value"
-        )
+        assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value')


 def test_create_batch_prediction_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.create_batch_prediction_job(
             job_service.CreateBatchPredictionJobRequest(),
-            parent="parent_value",
-            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(
-                name="name_value"
-            ),
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
         )


 @pytest.mark.asyncio
 async def test_create_batch_prediction_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.create_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_batch_prediction_job.BatchPredictionJob()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            gca_batch_prediction_job.BatchPredictionJob()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.create_batch_prediction_job(
-            parent="parent_value",
-            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(
-                name="name_value"
-            ),
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
         )

         # Establish that the underlying call was made with the expected
@@ -4320,36 +4705,31 @@ async def test_create_batch_prediction_job_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'

-        assert args[
-            0
-        ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(
-            name="name_value"
-        )
+        assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value')


 @pytest.mark.asyncio
 async def test_create_batch_prediction_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.create_batch_prediction_job(
             job_service.CreateBatchPredictionJobRequest(),
-            parent="parent_value",
-            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(
-                name="name_value"
-            ),
+            parent='parent_value',
+            batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'),
         )


-def test_get_batch_prediction_job(
-    transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest
-):
+def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -4358,15 +4738,20 @@ def test_get_batch_prediction_job(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = batch_prediction_job.BatchPredictionJob(
-            name="name_value",
-            display_name="display_name_value",
-            model="model_value",
+            name='name_value',
+
+            display_name='display_name_value',
+
+            model='model_value',
+
             generate_explanation=True,
+
             state=job_state.JobState.JOB_STATE_QUEUED,
+
         )

         response = client.get_batch_prediction_job(request)

@@ -4378,13 +4763,14 @@ def test_get_batch_prediction_job(
     assert args[0] == job_service.GetBatchPredictionJobRequest()

     # Establish that the response is the type that we expect.
+
     assert isinstance(response, batch_prediction_job.BatchPredictionJob)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'

-    assert response.display_name == "display_name_value"
+    assert response.display_name == 'display_name_value'

-    assert response.model == "model_value"
+    assert response.model == 'model_value'

     assert response.generate_explanation is True

@@ -4396,29 +4782,28 @@ def test_get_batch_prediction_job_from_dict():


 @pytest.mark.asyncio
-async def test_get_batch_prediction_job_async(transport: str = "grpc_asyncio"):
+async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetBatchPredictionJobRequest):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.GetBatchPredictionJobRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            batch_prediction_job.BatchPredictionJob(
-                name="name_value",
-                display_name="display_name_value",
-                model="model_value",
-                generate_explanation=True,
-                state=job_state.JobState.JOB_STATE_QUEUED,
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob(
+            name='name_value',
+            display_name='display_name_value',
+            model='model_value',
+            generate_explanation=True,
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        ))

         response = await client.get_batch_prediction_job(request)

@@ -4426,34 +4811,41 @@ async def test_get_batch_prediction_job_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]

-    assert args[0] == request
+    assert args[0] == job_service.GetBatchPredictionJobRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, batch_prediction_job.BatchPredictionJob)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'

-    assert response.display_name == "display_name_value"
+    assert response.display_name == 'display_name_value'

-    assert response.model == "model_value"
+    assert response.model == 'model_value'

     assert response.generate_explanation is True

     assert response.state == job_state.JobState.JOB_STATE_QUEUED


+@pytest.mark.asyncio
+async def test_get_batch_prediction_job_async_from_dict():
+    await test_get_batch_prediction_job_async(request_type=dict)
+
+
 def test_get_batch_prediction_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.GetBatchPredictionJobRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
         call.return_value = batch_prediction_job.BatchPredictionJob()

         client.get_batch_prediction_job(request)

@@ -4465,25 +4857,28 @@ def test_get_batch_prediction_job_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_get_batch_prediction_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.GetBatchPredictionJobRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_batch_prediction_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            batch_prediction_job.BatchPredictionJob()
-        )
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob())

         await client.get_batch_prediction_job(request)

@@ -4494,85 +4889,99 @@ async def test_get_batch_prediction_job_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_get_batch_prediction_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = batch_prediction_job.BatchPredictionJob()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.get_batch_prediction_job(name="name_value",)
+        client.get_batch_prediction_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_get_batch_prediction_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.get_batch_prediction_job(
-            job_service.GetBatchPredictionJobRequest(), name="name_value",
+            job_service.GetBatchPredictionJobRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_get_batch_prediction_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.get_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = batch_prediction_job.BatchPredictionJob()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            batch_prediction_job.BatchPredictionJob()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.get_batch_prediction_job(name="name_value",)
+        response = await client.get_batch_prediction_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_get_batch_prediction_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.get_batch_prediction_job(
-            job_service.GetBatchPredictionJobRequest(), name="name_value",
+            job_service.GetBatchPredictionJobRequest(),
+            name='name_value',
         )


-def test_list_batch_prediction_jobs(
-    transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest
-):
+def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -4581,11 +4990,12 @@ def test_list_batch_prediction_jobs(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = job_service.ListBatchPredictionJobsResponse(
-            next_page_token="next_page_token_value",
+            next_page_token='next_page_token_value',
+
         )

         response = client.list_batch_prediction_jobs(request)

@@ -4597,9 +5007,10 @@ def test_list_batch_prediction_jobs(
     assert args[0] == job_service.ListBatchPredictionJobsRequest()

     # Establish that the response is the type that we expect.
+
     assert isinstance(response, pagers.ListBatchPredictionJobsPager)

-    assert response.next_page_token == "next_page_token_value"
+    assert response.next_page_token == 'next_page_token_value'


 def test_list_batch_prediction_jobs_from_dict():
@@ -4607,25 +5018,24 @@ def test_list_batch_prediction_jobs_from_dict():


 @pytest.mark.asyncio
-async def test_list_batch_prediction_jobs_async(transport: str = "grpc_asyncio"):
+async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListBatchPredictionJobsRequest):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.ListBatchPredictionJobsRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListBatchPredictionJobsResponse(
-                next_page_token="next_page_token_value",
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse(
+            next_page_token='next_page_token_value',
+        ))

         response = await client.list_batch_prediction_jobs(request)

@@ -4633,26 +5043,33 @@ async def test_list_batch_prediction_jobs_async(transport: str = "grpc_asyncio")
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]

-    assert args[0] == request
+    assert args[0] == job_service.ListBatchPredictionJobsRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager)

-    assert response.next_page_token == "next_page_token_value"
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_from_dict():
+    await test_list_batch_prediction_jobs_async(request_type=dict)


 def test_list_batch_prediction_jobs_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.ListBatchPredictionJobsRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         call.return_value = job_service.ListBatchPredictionJobsResponse()

         client.list_batch_prediction_jobs(request)

@@ -4664,25 +5081,28 @@ def test_list_batch_prediction_jobs_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_list_batch_prediction_jobs_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.ListBatchPredictionJobsRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListBatchPredictionJobsResponse()
-        )
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse())

         await client.list_batch_prediction_jobs(request)

@@ -4693,87 +5113,104 @@ async def test_list_batch_prediction_jobs_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 def test_list_batch_prediction_jobs_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = job_service.ListBatchPredictionJobsResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_batch_prediction_jobs(parent="parent_value",)
+        client.list_batch_prediction_jobs(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 def test_list_batch_prediction_jobs_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_batch_prediction_jobs(
-            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+            job_service.ListBatchPredictionJobsRequest(),
+            parent='parent_value',
        )


 @pytest.mark.asyncio
 async def test_list_batch_prediction_jobs_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = job_service.ListBatchPredictionJobsResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListBatchPredictionJobsResponse()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.list_batch_prediction_jobs(parent="parent_value",)
+        response = await client.list_batch_prediction_jobs(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 @pytest.mark.asyncio
 async def test_list_batch_prediction_jobs_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.list_batch_prediction_jobs(
-            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+            job_service.ListBatchPredictionJobsRequest(),
+            parent='parent_value',
         )


 def test_list_batch_prediction_jobs_pager():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListBatchPredictionJobsResponse(
@@ -4782,14 +5219,17 @@ def test_list_batch_prediction_jobs_pager():
                     batch_prediction_job.BatchPredictionJob(),
                     batch_prediction_job.BatchPredictionJob(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[], next_page_token="def",
+                batch_prediction_jobs=[],
+                next_page_token='def',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
-                next_page_token="ghi",
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
             ),
             job_service.ListBatchPredictionJobsResponse(
                 batch_prediction_jobs=[
@@ -4802,7 +5242,9 @@ def test_list_batch_prediction_jobs_pager():

         metadata = ()
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
         )
         pager = client.list_batch_prediction_jobs(request={})

@@ -4810,18 +5252,18 @@ def test_list_batch_prediction_jobs_pager():
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(
-            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results
-        )
-
+        assert all(isinstance(i, batch_prediction_job.BatchPredictionJob)
+                   for i in results)

 def test_list_batch_prediction_jobs_pages():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs), "__call__"
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListBatchPredictionJobsResponse(
@@ -4830,14 +5272,17 @@ def test_list_batch_prediction_jobs_pages():
                     batch_prediction_job.BatchPredictionJob(),
                     batch_prediction_job.BatchPredictionJob(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[], next_page_token="def",
+                batch_prediction_jobs=[],
+                next_page_token='def',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
-                next_page_token="ghi",
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
            ),
             job_service.ListBatchPredictionJobsResponse(
                 batch_prediction_jobs=[
@@ -4848,20 +5293,19 @@ def test_list_batch_prediction_jobs_pages():
             RuntimeError,
         )
         pages = list(client.list_batch_prediction_jobs(request={}).pages)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token

-
 @pytest.mark.asyncio
 async def test_list_batch_prediction_jobs_async_pager():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListBatchPredictionJobsResponse(
@@ -4870,14 +5314,17 @@ async def test_list_batch_prediction_jobs_async_pager():
                     batch_prediction_job.BatchPredictionJob(),
                     batch_prediction_job.BatchPredictionJob(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[], next_page_token="def",
+                batch_prediction_jobs=[],
+                next_page_token='def',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
-                next_page_token="ghi",
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
             ),
             job_service.ListBatchPredictionJobsResponse(
                 batch_prediction_jobs=[
@@ -4888,27 +5335,25 @@ async def test_list_batch_prediction_jobs_async_pager():
             RuntimeError,
         )
         async_pager = await client.list_batch_prediction_jobs(request={},)
-        assert async_pager.next_page_token == "abc"
+        assert async_pager.next_page_token == 'abc'
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(
-            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses
-        )
-
+        assert all(isinstance(i, batch_prediction_job.BatchPredictionJob)
+                   for i in responses)

 @pytest.mark.asyncio
 async def test_list_batch_prediction_jobs_async_pages():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_batch_prediction_jobs),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_batch_prediction_jobs),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             job_service.ListBatchPredictionJobsResponse(
@@ -4917,14 +5362,17 @@ async def test_list_batch_prediction_jobs_async_pages():
                     batch_prediction_job.BatchPredictionJob(),
                     batch_prediction_job.BatchPredictionJob(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[], next_page_token="def",
+                batch_prediction_jobs=[],
+                next_page_token='def',
             ),
             job_service.ListBatchPredictionJobsResponse(
-                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
-                next_page_token="ghi",
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token='ghi',
             ),
             job_service.ListBatchPredictionJobsResponse(
                 batch_prediction_jobs=[
@@ -4937,15 +5385,14 @@ async def test_list_batch_prediction_jobs_async_pages():
         pages = []
         async for page_ in (await client.list_batch_prediction_jobs(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token


-def test_delete_batch_prediction_job(
-    transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest
-):
+def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -4954,10 +5401,10 @@ def test_delete_batch_prediction_job(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
+        call.return_value = operations_pb2.Operation(name='operations/spam')

         response = client.delete_batch_prediction_job(request)

@@ -4976,22 +5423,23 @@ def test_delete_batch_prediction_job_from_dict():


 @pytest.mark.asyncio
-async def test_delete_batch_prediction_job_async(transport: str = "grpc_asyncio"):
+async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteBatchPredictionJobRequest):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.DeleteBatchPredictionJobRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
+            operations_pb2.Operation(name='operations/spam')
         )

         response = await client.delete_batch_prediction_job(request)

@@ -5000,25 +5448,32 @@ async def test_delete_batch_prediction_job_async(transport: str = "grpc_asyncio"
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]

-    assert args[0] == request
+    assert args[0] == job_service.DeleteBatchPredictionJobRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_delete_batch_prediction_job_async_from_dict():
+    await test_delete_batch_prediction_job_async(request_type=dict)
+
+
 def test_delete_batch_prediction_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.DeleteBatchPredictionJobRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_batch_prediction_job), "__call__"
-    ) as call:
-        call.return_value = operations_pb2.Operation(name="operations/op")
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
+        call.return_value = operations_pb2.Operation(name='operations/op')

         client.delete_batch_prediction_job(request)

@@ -5029,25 +5484,28 @@ def test_delete_batch_prediction_job_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_delete_batch_prediction_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.DeleteBatchPredictionJobRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_batch_prediction_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/op")
-        )
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))

         await client.delete_batch_prediction_job(request)

@@ -5058,85 +5516,101 @@ async def test_delete_batch_prediction_job_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_delete_batch_prediction_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
+        call.return_value = operations_pb2.Operation(name='operations/op')

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.delete_batch_prediction_job(name="name_value",)
+        client.delete_batch_prediction_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_delete_batch_prediction_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.delete_batch_prediction_job(
-            job_service.DeleteBatchPredictionJobRequest(), name="name_value",
+            job_service.DeleteBatchPredictionJobRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_delete_batch_prediction_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.delete_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
+        call.return_value = operations_pb2.Operation(name='operations/op')

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
+            operations_pb2.Operation(name='operations/spam')
         )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.delete_batch_prediction_job(name="name_value",)
+        response = await client.delete_batch_prediction_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_delete_batch_prediction_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.delete_batch_prediction_job(
-            job_service.DeleteBatchPredictionJobRequest(), name="name_value",
+            job_service.DeleteBatchPredictionJobRequest(),
+            name='name_value',
         )


-def test_cancel_batch_prediction_job(
-    transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest
-):
+def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest):
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -5145,8 +5619,8 @@ def test_cancel_batch_prediction_job(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

@@ -5167,19 +5641,20 @@ def test_cancel_batch_prediction_job_from_dict():


 @pytest.mark.asyncio
-async def test_cancel_batch_prediction_job_async(transport: str = "grpc_asyncio"):
+async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelBatchPredictionJobRequest):
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.CancelBatchPredictionJobRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

@@ -5189,24 +5664,31 @@ async def test_cancel_batch_prediction_job_async(transport: str = "grpc_asyncio"
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]

-    assert args[0] == request
+    assert args[0] == job_service.CancelBatchPredictionJobRequest()

     # Establish that the response is the type that we expect.
     assert response is None


+@pytest.mark.asyncio
+async def test_cancel_batch_prediction_job_async_from_dict():
+    await test_cancel_batch_prediction_job_async(request_type=dict)
+
+
 def test_cancel_batch_prediction_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CancelBatchPredictionJobRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
         call.return_value = None

         client.cancel_batch_prediction_job(request)

@@ -5218,22 +5700,27 @@ def test_cancel_batch_prediction_job_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_cancel_batch_prediction_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = job_service.CancelBatchPredictionJobRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         await client.cancel_batch_prediction_job(request)

@@ -5245,75 +5732,92 @@ async def test_cancel_batch_prediction_job_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_cancel_batch_prediction_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.cancel_batch_prediction_job(name="name_value",)
+        client.cancel_batch_prediction_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_cancel_batch_prediction_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.cancel_batch_prediction_job(
-            job_service.CancelBatchPredictionJobRequest(), name="name_value",
+            job_service.CancelBatchPredictionJobRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_cancel_batch_prediction_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_batch_prediction_job), "__call__"
-    ) as call:
+            type(client.transport.cancel_batch_prediction_job),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.cancel_batch_prediction_job(name="name_value",)
+        response = await client.cancel_batch_prediction_job(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_cancel_batch_prediction_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.cancel_batch_prediction_job(
-            job_service.CancelBatchPredictionJobRequest(), name="name_value",
+            job_service.CancelBatchPredictionJobRequest(),
+            name='name_value',
         )

@@ -5324,7 +5828,8 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = JobServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
         )

     # It is an error to provide a credentials file and a transport instance.
@@ -5343,7 +5848,8 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = JobServiceClient(
-            client_options={"scopes": ["1", "2"]}, transport=transport,
+            client_options={"scopes": ["1", "2"]},
+            transport=transport,
         )

@@ -5371,13 +5877,13 @@ def test_transport_get_channel():
     assert channel


-@pytest.mark.parametrize(
-    "transport_class",
-    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
-)
+@pytest.mark.parametrize("transport_class", [
+    transports.JobServiceGrpcTransport,
+    transports.JobServiceGrpcAsyncIOTransport
+])
 def test_transport_adc(transport_class):
     # Test default credentials are used if not provided.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport_class()
         adc.assert_called_once()

@@ -5385,8 +5891,13 @@ def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client.transport, transports.JobServiceGrpcTransport,)
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    assert isinstance(
+        client.transport,
+        transports.JobServiceGrpcTransport,
+    )


 def test_job_service_base_transport_error():
     # An exception should be raised if a custom transport instance.
     with pytest.raises(exceptions.DuplicateCredentialArgs):
         transport = transports.JobServiceTransport(
             credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json",
+            credentials_file="credentials.json"
         )


 def test_job_service_base_transport():
     # Instantiate the base transport.
- with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -5411,27 +5920,27 @@ def test_job_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_custom_job", - "get_custom_job", - "list_custom_jobs", - "delete_custom_job", - "cancel_custom_job", - "create_data_labeling_job", - "get_data_labeling_job", - "list_data_labeling_jobs", - "delete_data_labeling_job", - "cancel_data_labeling_job", - "create_hyperparameter_tuning_job", - "get_hyperparameter_tuning_job", - "list_hyperparameter_tuning_jobs", - "delete_hyperparameter_tuning_job", - "cancel_hyperparameter_tuning_job", - "create_batch_prediction_job", - "get_batch_prediction_job", - "list_batch_prediction_jobs", - "delete_batch_prediction_job", - "cancel_batch_prediction_job", - ) + 'create_custom_job', + 'get_custom_job', + 'list_custom_jobs', + 'delete_custom_job', + 'cancel_custom_job', + 'create_data_labeling_job', + 'get_data_labeling_job', + 'list_data_labeling_jobs', + 'delete_data_labeling_job', + 'cancel_data_labeling_job', + 'create_hyperparameter_tuning_job', + 'get_hyperparameter_tuning_job', + 'list_hyperparameter_tuning_jobs', + 'delete_hyperparameter_tuning_job', + 'cancel_hyperparameter_tuning_job', + 'create_batch_prediction_job', + 'get_batch_prediction_job', + 'list_batch_prediction_jobs', + 'delete_batch_prediction_job', + 'cancel_batch_prediction_job', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -5444,28 +5953,23 @@ def test_job_service_base_transport(): def test_job_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_job_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
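# Aside: a self-contained sketch of the base-transport contract checked
# above -- every RPC name on the abstract transport raises
# NotImplementedError until a concrete transport (gRPC, gRPC-asyncio)
# overrides it. The class and method set here are illustrative stand-ins:
import pytest

class FakeBaseTransport:
    def create_custom_job(self, request):
        raise NotImplementedError()

    def cancel_batch_prediction_job(self, request):
        raise NotImplementedError()

def test_fake_base_transport_raises():
    transport = FakeBaseTransport()
    for method in ('create_custom_job', 'cancel_batch_prediction_job'):
        with pytest.raises(NotImplementedError):
            getattr(transport, method)(request=object())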
- with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.JobServiceTransport() @@ -5474,11 +5978,11 @@ def test_job_service_base_transport_with_adc(): def test_job_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) JobServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -5486,70 +5990,62 @@ def test_job_service_auth_adc(): def test_job_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) - def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_job_service_host_with_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_job_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.insecure_channel('http://localhost/') # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.insecure_channel('http://localhost/') # Check that channel is used if provided. 
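# Aside: the host assertions above rely on the GAPIC convention that an
# api_endpoint without an explicit port gets ':443' appended. A hypothetical
# re-implementation of just that rule, for illustration:
def with_default_port(api_endpoint, default_port=443):
    host, _, port = api_endpoint.rpartition(':')
    if host and port.isdigit():
        return api_endpoint  # an explicit port wins
    return '{}:{}'.format(api_endpoint, default_port)

assert with_default_port('aiplatform.googleapis.com') == 'aiplatform.googleapis.com:443'
assert with_default_port('aiplatform.googleapis.com:8000') == 'aiplatform.googleapis.com:8000'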
transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -5558,7 +6054,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -5574,27 +6070,27 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport]) +def test_job_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -5611,7 +6107,9 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -5620,12 +6118,16 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): def test_job_service_grpc_lro_client(): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + 
credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -5633,36 +6135,36 @@ def test_job_service_grpc_lro_client(): def test_job_service_grpc_lro_async_client(): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_batch_prediction_job_path(): project = "squid" location = "clam" batch_prediction_job = "whelk" - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, location=location, batch_prediction_job=batch_prediction_job, - ) - actual = JobServiceClient.batch_prediction_job_path( - project, location, batch_prediction_job - ) + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, ) + actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job) assert expected == actual def test_parse_batch_prediction_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "batch_prediction_job": "nudibranch", + } path = JobServiceClient.batch_prediction_job_path(**expected) @@ -5670,24 +6172,22 @@ def test_parse_batch_prediction_job_path(): actual = JobServiceClient.parse_batch_prediction_job_path(path) assert expected == actual - def test_custom_job_path(): project = "cuttlefish" location = "mussel" custom_job = "winkle" - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, ) actual = JobServiceClient.custom_job_path(project, location, custom_job) assert expected == actual def test_parse_custom_job_path(): expected = { - "project": "nautilus", - "location": "scallop", - "custom_job": "abalone", + "project": "nautilus", + "location": "scallop", + "custom_job": "abalone", + } path = JobServiceClient.custom_job_path(**expected) @@ -5695,26 +6195,22 @@ def test_parse_custom_job_path(): actual = JobServiceClient.parse_custom_job_path(path) assert expected == actual - def test_data_labeling_job_path(): project = "squid" location = "clam" data_labeling_job = "whelk" - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) - actual = JobServiceClient.data_labeling_job_path( - 
project, location, data_labeling_job - ) + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, ) + actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job) assert expected == actual def test_parse_data_labeling_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", + } path = JobServiceClient.data_labeling_job_path(**expected) @@ -5722,24 +6218,22 @@ def test_parse_data_labeling_job_path(): actual = JobServiceClient.parse_data_labeling_job_path(path) assert expected == actual - def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = JobServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } path = JobServiceClient.dataset_path(**expected) @@ -5747,28 +6241,22 @@ def test_parse_dataset_path(): actual = JobServiceClient.parse_dataset_path(path) assert expected == actual - def test_hyperparameter_tuning_job_path(): project = "squid" location = "clam" hyperparameter_tuning_job = "whelk" - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) - actual = JobServiceClient.hyperparameter_tuning_job_path( - project, location, hyperparameter_tuning_job - ) + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, ) + actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job) assert expected == actual def test_parse_hyperparameter_tuning_job_path(): expected = { - "project": "octopus", - "location": "oyster", - "hyperparameter_tuning_job": "nudibranch", + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", + } path = JobServiceClient.hyperparameter_tuning_job_path(**expected) @@ -5776,24 +6264,22 @@ def test_parse_hyperparameter_tuning_job_path(): actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = JobServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = JobServiceClient.model_path(**expected) @@ 
-5801,20 +6287,18 @@ def test_parse_model_path(): actual = JobServiceClient.parse_model_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = JobServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = JobServiceClient.common_billing_account_path(**expected) @@ -5822,18 +6306,18 @@ def test_parse_common_billing_account_path(): actual = JobServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = JobServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = JobServiceClient.common_folder_path(**expected) @@ -5841,18 +6325,18 @@ def test_parse_common_folder_path(): actual = JobServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = JobServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = JobServiceClient.common_organization_path(**expected) @@ -5860,18 +6344,18 @@ def test_parse_common_organization_path(): actual = JobServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = JobServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = JobServiceClient.common_project_path(**expected) @@ -5879,22 +6363,20 @@ def test_parse_common_project_path(): actual = JobServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = JobServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = JobServiceClient.common_location_path(**expected) @@ -5906,19 +6388,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: client = JobServiceClient( - 
credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep: transport_class = JobServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 01aece3a3b..f2c4c7cda9 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.migration_service import ( - MigrationServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.migration_service import ( - MigrationServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient from google.cloud.aiplatform_v1beta1.services.migration_service import pagers from google.cloud.aiplatform_v1beta1.services.migration_service import transports from google.cloud.aiplatform_v1beta1.types import migratable_resource @@ -57,11 +53,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
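# Aside: the JobServiceClient *_path/parse_*_path round-trips exercised
# earlier boil down to a format-string helper paired with a regex inverse.
# A sketch of that pairing (illustrative, not the generated implementation):
import re

def batch_prediction_job_path(project, location, batch_prediction_job):
    return 'projects/{}/locations/{}/batchPredictionJobs/{}'.format(
        project, location, batch_prediction_job)

def parse_batch_prediction_job_path(path):
    m = re.match(r'^projects/(?P<project>.+?)/locations/(?P<location>.+?)'
                 r'/batchPredictionJobs/(?P<batch_prediction_job>.+?)$', path)
    return m.groupdict() if m else {}

path = batch_prediction_job_path('squid', 'clam', 'whelk')
assert parse_batch_prediction_job_path(path) == {
    'project': 'squid', 'location': 'clam', 'batch_prediction_job': 'whelk'}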
def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -72,36 +64,17 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert MigrationServiceClient._get_default_mtls_endpoint(None) is None - assert ( - MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) + assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient] -) +@pytest.mark.parametrize("client_class", [MigrationServiceClient, MigrationServiceAsyncClient]) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -109,7 +82,7 @@ def test_migration_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_migration_service_client_get_transport_class(): @@ -120,44 +93,29 @@ def test_migration_service_client_get_transport_class(): assert transport == transports.MigrationServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - MigrationServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceClient), -) -@mock.patch.object( - MigrationServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceAsyncClient), -) -def test_migration_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, 
transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) +def test_migration_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -173,7 +131,7 @@ def test_migration_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -189,7 +147,7 @@ def test_migration_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -209,15 +167,13 @@ def test_migration_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
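# Aside: test__get_default_mtls_endpoint above pins down a pure string rule:
# '<name>.googleapis.com' gains '.mtls.', '<name>.sandbox.googleapis.com'
# becomes '<name>.mtls.sandbox.googleapis.com', and already-mtls or
# non-Google hosts pass through. A sketch of a function satisfying those
# assertions (the real classmethod lives on the generated client):
def get_default_mtls_endpoint(api_endpoint):
    if not api_endpoint:
        return api_endpoint
    if '.mtls.' in api_endpoint or not api_endpoint.endswith('.googleapis.com'):
        return api_endpoint
    if api_endpoint.endswith('.sandbox.googleapis.com'):
        return api_endpoint.replace('.sandbox.googleapis.com',
                                    '.mtls.sandbox.googleapis.com')
    return api_endpoint.replace('.googleapis.com', '.mtls.googleapis.com')

assert get_default_mtls_endpoint('aiplatform.googleapis.com') == 'aiplatform.mtls.googleapis.com'
assert get_default_mtls_endpoint('aiplatform.mtls.googleapis.com') == 'aiplatform.mtls.googleapis.com'
assert get_default_mtls_endpoint('api.example.com') == 'api.example.com'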
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -230,66 +186,26 @@ def test_migration_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - MigrationServiceClient, - transports.MigrationServiceGrpcTransport, - "grpc", - "true", - ), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - MigrationServiceClient, - transports.MigrationServiceGrpcTransport, - "grpc", - "false", - ), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - MigrationServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceClient), -) -@mock.patch.object( - MigrationServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(MigrationServiceAsyncClient), -) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient)) +@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_migration_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
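# Aside: the GOOGLE_API_USE_CLIENT_CERTIFICATE branches above are steered
# with mock.patch.dict, which scopes an os.environ override to the with
# block. A standalone sketch of that mechanism:
import os
from unittest import mock

def use_client_cert():
    return os.environ.get('GOOGLE_API_USE_CLIENT_CERTIFICATE', 'false') == 'true'

with mock.patch.dict(os.environ, {'GOOGLE_API_USE_CLIENT_CERTIFICATE': 'true'}):
    assert use_client_cert()
with mock.patch.dict(os.environ, {}, clear=True):
    assert not use_client_cert()  # the override evaporates outside its block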
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -312,21 +228,11 @@ def test_migration_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -336,9 +242,7 @@ def test_migration_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -353,17 +257,10 @@ def test_migration_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
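# Aside: the ADC-certificate branches above patch properties via
# new_callable=mock.PropertyMock; a property must be patched on the class,
# and the mock's return_value becomes the attribute value. Minimal sketch
# with a stand-in class:
from unittest import mock

class FakeSslCredentials:
    @property
    def is_mtls(self):
        raise RuntimeError('would consult real ADC')

with mock.patch.object(FakeSslCredentials, 'is_mtls',
                       new_callable=mock.PropertyMock) as is_mtls_mock:
    is_mtls_mock.return_value = True
    assert FakeSslCredentials().is_mtls is True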
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -378,23 +275,16 @@ def test_migration_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_migration_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -407,24 +297,16 @@ def test_migration_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), - ( - MigrationServiceAsyncClient, - transports.MigrationServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_migration_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
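# Aside: the (client_class, transport_class, transport_name) matrices above
# use pytest.mark.parametrize, which turns each tuple into its own test
# case. The same mechanism in miniature:
import pytest

@pytest.mark.parametrize('transport_name,expected_flavor', [
    ('grpc', 'sync'),
    ('grpc_asyncio', 'async'),
])
def test_transport_flavor(transport_name, expected_flavor):
    flavor = 'async' if transport_name.endswith('_asyncio') else 'sync'
    assert flavor == expected_flavor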
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -439,12 +321,10 @@ def test_migration_service_client_client_options_credentials_file( def test_migration_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = MigrationServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -457,12 +337,10 @@ def test_migration_service_client_client_options_from_dict(): ) -def test_search_migratable_resources( - transport: str = "grpc", - request_type=migration_service.SearchMigratableResourcesRequest, -): +def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -471,11 +349,12 @@ def test_search_migratable_resources( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.search_migratable_resources(request) @@ -487,9 +366,10 @@ def test_search_migratable_resources( assert args[0] == migration_service.SearchMigratableResourcesRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.SearchMigratableResourcesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_search_migratable_resources_from_dict(): @@ -497,25 +377,24 @@ def test_search_migratable_resources_from_dict(): @pytest.mark.asyncio -async def test_search_migratable_resources_async(transport: str = "grpc_asyncio"): +async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = migration_service.SearchMigratableResourcesRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
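# Aside: the *_from_dict variants above re-run each RPC test with
# request_type=dict, leaning on proto-plus to coerce a plain mapping into
# the request message. A dependency-free sketch of that coercion contract
# (FakeSearchRequest is a hypothetical stand-in):
class FakeSearchRequest:
    def __init__(self, parent=''):
        self.parent = parent

def coerce(request_type, request):
    if isinstance(request, request_type):
        return request
    return request_type(**request)  # e.g. a dict of field values

req = coerce(FakeSearchRequest, {'parent': 'parent/value'})
assert isinstance(req, FakeSearchRequest) and req.parent == 'parent/value'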
with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse( + next_page_token='next_page_token_value', + )) response = await client.search_migratable_resources(request) @@ -523,26 +402,33 @@ async def test_search_migratable_resources_async(transport: str = "grpc_asyncio" assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == migration_service.SearchMigratableResourcesRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_search_migratable_resources_async_from_dict(): + await test_search_migratable_resources_async(request_type=dict) def test_search_migratable_resources_field_headers(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: call.return_value = migration_service.SearchMigratableResourcesResponse() client.search_migratable_resources(request) @@ -554,7 +440,10 @@ def test_search_migratable_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -566,15 +455,13 @@ async def test_search_migratable_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.SearchMigratableResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse() - ) + type(client.transport.search_migratable_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) await client.search_migratable_resources(request) @@ -585,39 +472,49 @@ async def test_search_migratable_resources_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_search_migratable_resources_flattened(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.search_migratable_resources(parent="parent_value",) + client.search_migratable_resources( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_search_migratable_resources_flattened_error(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), parent="parent_value", + migration_service.SearchMigratableResourcesRequest(), + parent='parent_value', ) @@ -629,24 +526,24 @@ async def test_search_migratable_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = migration_service.SearchMigratableResourcesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - migration_service.SearchMigratableResourcesResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.search_migratable_resources(parent="parent_value",) + response = await client.search_migratable_resources( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio @@ -659,17 +556,20 @@ async def test_search_migratable_resources_flattened_error_async(): # fields is an error. 
with pytest.raises(ValueError): await client.search_migratable_resources( - migration_service.SearchMigratableResourcesRequest(), parent="parent_value", + migration_service.SearchMigratableResourcesRequest(), + parent='parent_value', ) def test_search_migratable_resources_pager(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -678,14 +578,17 @@ def test_search_migratable_resources_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -698,7 +601,9 @@ def test_search_migratable_resources_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.search_migratable_resources(request={}) @@ -706,18 +611,18 @@ def test_search_migratable_resources_pager(): results = [i for i in pager] assert len(results) == 6 - assert all( - isinstance(i, migratable_resource.MigratableResource) for i in results - ) - + assert all(isinstance(i, migratable_resource.MigratableResource) + for i in results) def test_search_migratable_resources_pages(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), "__call__" - ) as call: + type(client.transport.search_migratable_resources), + '__call__') as call: # Set the response to a series of pages. 
call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -726,14 +631,17 @@ def test_search_migratable_resources_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -744,20 +652,19 @@ def test_search_migratable_resources_pages(): RuntimeError, ) pages = list(client.search_migratable_resources(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_search_migratable_resources_async_pager(): - client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.search_migratable_resources), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.search_migratable_resources), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -766,14 +673,17 @@ async def test_search_migratable_resources_async_pager(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -784,27 +694,25 @@ async def test_search_migratable_resources_async_pager(): RuntimeError, ) async_pager = await client.search_migratable_resources(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all( - isinstance(i, migratable_resource.MigratableResource) for i in responses - ) - + assert all(isinstance(i, migratable_resource.MigratableResource) + for i in responses) @pytest.mark.asyncio async def test_search_migratable_resources_async_pages(): - client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.search_migratable_resources), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.search_migratable_resources), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( migration_service.SearchMigratableResourcesResponse( @@ -813,14 +721,17 @@ async def test_search_migratable_resources_async_pages(): migratable_resource.MigratableResource(), migratable_resource.MigratableResource(), ], - next_page_token="abc", + next_page_token='abc', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[], next_page_token="def", + migratable_resources=[], + next_page_token='def', ), migration_service.SearchMigratableResourcesResponse( - migratable_resources=[migratable_resource.MigratableResource(),], - next_page_token="ghi", + migratable_resources=[ + migratable_resource.MigratableResource(), + ], + next_page_token='ghi', ), migration_service.SearchMigratableResourcesResponse( migratable_resources=[ @@ -833,15 +744,14 @@ async def test_search_migratable_resources_async_pages(): pages = [] async for page_ in (await client.search_migratable_resources(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_batch_migrate_resources( - transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest -): +def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -850,10 +760,10 @@ def test_batch_migrate_resources( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.batch_migrate_resources(request) @@ -872,22 +782,23 @@ def test_batch_migrate_resources_from_dict(): @pytest.mark.asyncio -async def test_batch_migrate_resources_async(transport: str = "grpc_asyncio"): +async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = migration_service.BatchMigrateResourcesRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. 
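# Aside: the pager tests above drive a common contract -- each response
# carries a next_page_token chaining to the next page, and the pager
# flattens the per-page resources into one iterable. A hypothetical minimal
# pager honoring that contract:
class FakePage:
    def __init__(self, items, next_page_token):
        self.items = items
        self.next_page_token = next_page_token

class FakePager:
    def __init__(self, fetch):
        self._fetch = fetch  # fetch(token) -> FakePage

    @property
    def pages(self):
        token = ''
        while True:
            page = self._fetch(token)
            yield page
            token = page.next_page_token
            if not token:
                break

    def __iter__(self):
        for page in self.pages:
            for item in page.items:
                yield item

# Mirrors the 3 + 0 + 1 + 2 = 6 resources spread across four pages above.
responses = {
    '': FakePage([1, 2, 3], 'abc'),
    'abc': FakePage([], 'def'),
    'def': FakePage([4], 'ghi'),
    'ghi': FakePage([5, 6], ''),
}
assert list(FakePager(responses.get)) == [1, 2, 3, 4, 5, 6]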
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.batch_migrate_resources(request) @@ -896,25 +807,32 @@ async def test_batch_migrate_resources_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == migration_service.BatchMigrateResourcesRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_batch_migrate_resources_async_from_dict(): + await test_batch_migrate_resources_async(request_type=dict) + + def test_batch_migrate_resources_field_headers(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.batch_migrate_resources(request) @@ -925,7 +843,10 @@ def test_batch_migrate_resources_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -937,15 +858,13 @@ async def test_batch_migrate_resources_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = migration_service.BatchMigrateResourcesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.batch_migrate_resources), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.batch_migrate_resources(request) @@ -956,30 +875,29 @@ async def test_batch_migrate_resources_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_batch_migrate_resources_flattened(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
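
# The two field-header tests above pin down how request routing works: the
# value of `request.parent` is copied into an `x-goog-request-params` metadata
# entry that travels with the call. A hedged sketch of that header shape (the
# generated clients build it internally; the names below are local to this
# sketch only):
parent = 'parent/value'
routing_metadata = ('x-goog-request-params', 'parent={}'.format(parent))
assert routing_metadata == ('x-goog-request-params', 'parent=parent/value')
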
with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_migrate_resources( - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) # Establish that the underlying call was made with the expected @@ -987,33 +905,23 @@ def test_batch_migrate_resources_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].migrate_resource_requests == [ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ] + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] def test_batch_migrate_resources_flattened_error(): - client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) @@ -1025,25 +933,19 @@ async def test_batch_migrate_resources_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.batch_migrate_resources), "__call__" - ) as call: + type(client.transport.batch_migrate_resources), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.batch_migrate_resources( - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) # Establish that the underlying call was made with the expected @@ -1051,15 +953,9 @@ async def test_batch_migrate_resources_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].migrate_resource_requests == [ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ] + assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))] @pytest.mark.asyncio @@ -1073,14 +969,8 @@ async def test_batch_migrate_resources_flattened_error_async(): with pytest.raises(ValueError): await client.batch_migrate_resources( migration_service.BatchMigrateResourcesRequest(), - parent="parent_value", - migrate_resource_requests=[ - migration_service.MigrateResourceRequest( - migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( - endpoint="endpoint_value" - ) - ) - ], + parent='parent_value', + migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))], ) @@ -1091,7 +981,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1110,7 +1001,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = MigrationServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1138,16 +1030,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1155,8 +1044,13 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
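
# The flattened-error tests above all hit the same guard: a generated method
# accepts either a full request object or individual keyword fields, never
# both. A simplified sketch of that check, using a hypothetical helper that
# mirrors the behavior the tests assert rather than the generated code itself:
import pytest


def _validate_flattened(request, **fields):
    has_flattened = any(value is not None for value in fields.values())
    if request is not None and has_flattened:
        raise ValueError('If the `request` argument is set, then none of '
                         'the individual field arguments should be set.')


# Passing both a request object and a flattened field raises, as in the tests:
with pytest.raises(ValueError):
    _validate_flattened(request=object(), parent='parent_value')
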
- client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.MigrationServiceGrpcTransport, + ) def test_migration_service_base_transport_error(): @@ -1164,15 +1058,13 @@ def test_migration_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_migration_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.MigrationServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1181,9 +1073,9 @@ def test_migration_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "search_migratable_resources", - "batch_migrate_resources", - ) + 'search_migratable_resources', + 'batch_migrate_resources', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1196,28 +1088,23 @@ def test_migration_service_base_transport(): def test_migration_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_migration_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
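
# A sketch of the base-transport behavior asserted above, using a stand-in
# class rather than the real MigrationServiceTransport: the abstract transport
# wires up the method names, but every RPC raises NotImplementedError until a
# concrete transport (gRPC, gRPC asyncio) overrides it.
import pytest


class _StubTransport:
    def search_migratable_resources(self, request):
        raise NotImplementedError()

    def batch_migrate_resources(self, request):
        raise NotImplementedError()


for method in ('search_migratable_resources', 'batch_migrate_resources'):
    with pytest.raises(NotImplementedError):
        getattr(_StubTransport(), method)(request=object())
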
- with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.MigrationServiceTransport() @@ -1226,11 +1113,11 @@ def test_migration_service_base_transport_with_adc(): def test_migration_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) MigrationServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -1238,75 +1125,62 @@ def test_migration_service_auth_adc(): def test_migration_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) - def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_migration_service_host_with_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_migration_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.insecure_channel('http://localhost/') # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.insecure_channel('http://localhost/') # Check that channel is used if provided. 
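
# The ADC tests above reduce to one behavior: when no credentials are given,
# google.auth.default() is consulted with the cloud-platform scope. A minimal
# sketch of that lookup under a mock, assuming only google-auth and the
# standard library:
from unittest import mock

from google import auth
from google.auth import credentials

with mock.patch.object(auth, 'default') as adc:
    adc.return_value = (credentials.AnonymousCredentials(), None)
    creds, _ = auth.default(
        scopes=('https://www.googleapis.com/auth/cloud-platform',))
assert isinstance(creds, credentials.AnonymousCredentials)
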
transport = transports.MigrationServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) def test_migration_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1315,7 +1189,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1331,30 +1205,27 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize( - "transport_class", - [ - transports.MigrationServiceGrpcTransport, - transports.MigrationServiceGrpcAsyncIOTransport, - ], -) -def test_migration_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport]) +def test_migration_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1371,7 +1242,9 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1380,12 +1253,16 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): def test_migration_service_grpc_lro_client(): client = MigrationServiceClient( - 
credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1393,36 +1270,36 @@ def test_migration_service_grpc_lro_client(): def test_migration_service_grpc_lro_async_client(): client = MigrationServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_annotated_dataset_path(): project = "squid" dataset = "clam" annotated_dataset = "whelk" - expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( - project=project, dataset=dataset, annotated_dataset=annotated_dataset, - ) - actual = MigrationServiceClient.annotated_dataset_path( - project, dataset, annotated_dataset - ) + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, ) + actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset) assert expected == actual def test_parse_annotated_dataset_path(): expected = { - "project": "octopus", - "dataset": "oyster", - "annotated_dataset": "nudibranch", + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", + } path = MigrationServiceClient.annotated_dataset_path(**expected) @@ -1430,24 +1307,22 @@ def test_parse_annotated_dataset_path(): actual = MigrationServiceClient.parse_annotated_dataset_path(path) assert expected == actual - def test_dataset_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1455,22 +1330,20 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_dataset_path(): project = "squid" dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, - ) + expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, dataset) assert 
expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", - "dataset": "octopus", + "project": "whelk", + "dataset": "octopus", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1478,24 +1351,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_dataset_path(): project = "oyster" location = "nudibranch" dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", - "dataset": "nautilus", + "project": "mussel", + "location": "winkle", + "dataset": "nautilus", + } path = MigrationServiceClient.dataset_path(**expected) @@ -1503,24 +1374,22 @@ def test_parse_dataset_path(): actual = MigrationServiceClient.parse_dataset_path(path) assert expected == actual - def test_model_path(): project = "scallop" location = "abalone" model = "squid" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", + "project": "clam", + "location": "whelk", + "model": "octopus", + } path = MigrationServiceClient.model_path(**expected) @@ -1528,24 +1397,22 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual - def test_model_path(): project = "oyster" location = "nudibranch" model = "cuttlefish" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = MigrationServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "mussel", - "location": "winkle", - "model": "nautilus", + "project": "mussel", + "location": "winkle", + "model": "nautilus", + } path = MigrationServiceClient.model_path(**expected) @@ -1553,24 +1420,22 @@ def test_parse_model_path(): actual = MigrationServiceClient.parse_model_path(path) assert expected == actual - def test_version_path(): project = "scallop" model = "abalone" version = "squid" - expected = "projects/{project}/models/{model}/versions/{version}".format( - project=project, model=model, version=version, - ) + expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, ) actual = MigrationServiceClient.version_path(project, model, version) assert expected == actual def test_parse_version_path(): expected = { - "project": "clam", - "model": "whelk", - "version": "octopus", + "project": "clam", + "model": "whelk", + "version": "octopus", + } path = MigrationServiceClient.version_path(**expected) @@ -1578,20 +1443,18 @@ def test_parse_version_path(): actual 
= MigrationServiceClient.parse_version_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = MigrationServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nudibranch", + "billing_account": "nudibranch", + } path = MigrationServiceClient.common_billing_account_path(**expected) @@ -1599,18 +1462,18 @@ def test_parse_common_billing_account_path(): actual = MigrationServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = MigrationServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "mussel", + "folder": "mussel", + } path = MigrationServiceClient.common_folder_path(**expected) @@ -1618,18 +1481,18 @@ def test_parse_common_folder_path(): actual = MigrationServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = MigrationServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nautilus", + "organization": "nautilus", + } path = MigrationServiceClient.common_organization_path(**expected) @@ -1637,18 +1500,18 @@ def test_parse_common_organization_path(): actual = MigrationServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = MigrationServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "abalone", + "project": "abalone", + } path = MigrationServiceClient.common_project_path(**expected) @@ -1656,22 +1519,20 @@ def test_parse_common_project_path(): actual = MigrationServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = MigrationServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "whelk", - "location": "octopus", + "project": "whelk", + "location": "octopus", + } path = MigrationServiceClient.common_location_path(**expected) @@ -1683,19 +1544,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.MigrationServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: 
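
# Every path-helper test above is a round-trip through a single template, and
# the parse_* class methods recover the fields with a regular expression. A
# stdlib-only sketch of that round-trip, using the dataset template exercised
# above:
import re

template = 'projects/{project}/locations/{location}/datasets/{dataset}'
path = template.format(project='squid', location='clam', dataset='whelk')
match = re.match(
    r'^projects/(?P<project>.+?)/locations/(?P<location>.+?)'
    r'/datasets/(?P<dataset>.+?)$',
    path,
)
assert match.groupdict() == {
    'project': 'squid', 'location': 'clam', 'dataset': 'whelk',
}
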
client = MigrationServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.MigrationServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep: transport_class = MigrationServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index d3c450ffb7..af1e117cc3 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -35,9 +35,7 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.model_service import ( - ModelServiceAsyncClient, -) +from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceAsyncClient from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.services.model_service import transports @@ -67,11 +65,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -82,30 +76,17 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) + assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi @pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -113,7 +94,7 @@ def test_model_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_model_service_client_get_transport_class(): @@ -124,42 +105,29 @@ def test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) -def test_model_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_client_options(client_class, transport_class, 
transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -175,7 +143,7 @@ def test_model_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -191,7 +159,7 @@ def test_model_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -211,15 +179,13 @@ def test_model_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
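
# A condensed sketch of the endpoint selection these cases walk through,
# assuming a simplified rule in a hypothetical helper (the real client also
# folds in client certificates for the 'auto' case, and raises
# MutualTLSChannelError rather than ValueError for unsupported values):
def _pick_endpoint(use_mtls_env, default, mtls_default, have_cert=False):
    if use_mtls_env == 'never':
        return default
    if use_mtls_env == 'always':
        return mtls_default
    if use_mtls_env == 'auto':
        return mtls_default if have_cert else default
    raise ValueError('Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value')


assert _pick_endpoint('never', 'api.example.com',
                      'api.mtls.example.com') == 'api.example.com'
assert _pick_endpoint('always', 'api.example.com',
                      'api.mtls.example.com') == 'api.mtls.example.com'
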
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -232,54 +198,26 @@ def test_model_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -302,21 +240,11 @@ def test_model_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -326,9 +254,7 @@ def test_model_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -343,17 +269,10 @@ def test_model_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -368,23 +287,16 @@ def test_model_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -397,24 +309,16 @@ def test_model_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
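
# The scopes and credentials-file cases above only shuttle values from
# ClientOptions into the transport constructor; a hedged usage sketch of the
# options object itself, assuming only google-api-core:
from google.api_core import client_options

options = client_options.ClientOptions(
    api_endpoint='aiplatform.googleapis.com',
    scopes=['https://www.googleapis.com/auth/cloud-platform'],
)
assert options.api_endpoint == 'aiplatform.googleapis.com'
assert options.scopes == ['https://www.googleapis.com/auth/cloud-platform']
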
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -429,11 +333,11 @@ def test_model_service_client_client_options_credentials_file( def test_model_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + client = ModelServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -445,11 +349,10 @@ def test_model_service_client_client_options_from_dict(): ) -def test_upload_model( - transport: str = "grpc", request_type=model_service.UploadModelRequest -): +def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -457,9 +360,11 @@ def test_upload_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.upload_model(request) @@ -478,20 +383,23 @@ def test_upload_model_from_dict(): @pytest.mark.asyncio -async def test_upload_model_async(transport: str = "grpc_asyncio"): +async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UploadModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = model_service.UploadModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.upload_model(request) @@ -500,23 +408,32 @@ async def test_upload_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == model_service.UploadModelRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_upload_model_async_from_dict(): + await test_upload_model_async(request_type=dict) + + def test_upload_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.upload_model(request) @@ -527,23 +444,28 @@ def test_upload_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UploadModelRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.upload_model(request) @@ -554,21 +476,29 @@ async def test_upload_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_upload_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
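
# The flattened tests that follow check that keyword arguments end up as
# fields on the request message; constructing the message directly shows the
# mapping. A minimal sketch, assuming the same `types` imports used at the
# top of this module:
from google.cloud.aiplatform_v1beta1.types import model as gca_model
from google.cloud.aiplatform_v1beta1.types import model_service

request = model_service.UploadModelRequest(
    parent='parent_value',
    model=gca_model.Model(name='name_value'),
)
assert request.parent == 'parent_value'
assert request.model.name == 'name_value'
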
- with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -576,40 +506,47 @@ def test_upload_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') def test_upload_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.upload_model( model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) @pytest.mark.asyncio async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + with mock.patch.object( + type(client.transport.upload_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -617,28 +554,31 @@ async def test_upload_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') @pytest.mark.asyncio async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.upload_model( model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), + parent='parent_value', + model=gca_model.Model(name='name_value'), ) -def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): +def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -646,21 +586,31 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + ) response = client.get_model(request) @@ -672,33 +622,28 @@ def test_get_model(transport: str = "grpc", request_type=model_service.GetModelR assert args[0] == model_service.GetModelRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_get_model_from_dict(): @@ -706,38 +651,33 @@ def test_get_model_from_dict(): @pytest.mark.asyncio -async def test_get_model_async(transport: str = "grpc_asyncio"): +async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. 
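+        # Async stubs must return an awaitable, so the fake response below is
+        # wrapped in grpc_helpers_async.FakeUnaryUnaryCall; awaiting
+        # client.get_model(request) then resolves to the inner model.Model.
+        # A minimal sketch of the pattern (values as used below):
+        #
+        #     call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+        #         model.Model(name='name_value'))
+        #     response = await client.get_model(request)   # -> model.Model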
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) response = await client.get_model(request) @@ -745,48 +685,51 @@ async def test_get_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == model_service.GetModelRequest() # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' + + assert response.display_name == 'display_name_value' + + assert response.description == 'description_value' - assert response.display_name == "display_name_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.description == "description_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.artifact_uri == "artifact_uri_value" + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.etag == 'etag_value' - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - assert response.etag == "etag_value" +@pytest.mark.asyncio +async def test_get_model_async_from_dict(): + await test_get_model_async(request_type=dict) def test_get_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
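+    # Routing: the client copies URI-bound request fields into the
+    # 'x-goog-request-params' metadata entry so the backend can route the
+    # call without inspecting the payload; the assertion further down checks
+    # for the ('x-goog-request-params', 'name=name/value') pair.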
- with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: call.return_value = model.Model() client.get_model(request) @@ -798,20 +741,27 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -823,79 +773,99 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model(name="name_value",) + client.get_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), name="name_value", + model_service.GetModelRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.get_model), "__call__") as call: + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: # Designate an appropriate return value for the call. 
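+        # Note: the generator emits two assignments here; the first (a plain
+        # model.Model()) is immediately superseded by the awaitable
+        # FakeUnaryUnaryCall wrapper on the next line and has no effect.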
call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model(name="name_value",) + response = await client.get_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), name="name_value", + model_service.GetModelRequest(), + name='name_value', ) -def test_list_models( - transport: str = "grpc", request_type=model_service.ListModelsRequest -): +def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -903,10 +873,13 @@ def test_list_models( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_models(request) @@ -918,9 +891,10 @@ def test_list_models( assert args[0] == model_service.ListModelsRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_models_from_dict(): @@ -928,21 +902,24 @@ def test_list_models_from_dict(): @pytest.mark.asyncio -async def test_list_models_async(transport: str = "grpc_asyncio"): +async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. 
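+        # Even though the stub yields a single ListModelsResponse, the async
+        # client wraps it in a ListModelsAsyncPager (asserted further down),
+        # which re-invokes `call` as additional pages are consumed.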
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse(next_page_token="next_page_token_value",) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_models(request) @@ -950,24 +927,33 @@ async def test_list_models_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == model_service.ListModelsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_models_async_from_dict(): + await test_list_models_async(request_type=dict) def test_list_models_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: call.return_value = model_service.ListModelsResponse() client.list_models(request) @@ -979,23 +965,28 @@ def test_list_models_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) await client.list_models(request) @@ -1006,98 +997,138 @@ async def test_list_models_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_models_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_models(parent="parent_value",) + client.list_models( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_models_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_models( - model_service.ListModelsRequest(), parent="parent_value", + model_service.ListModelsRequest(), + parent='parent_value', ) @pytest.mark.asyncio async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_models(parent="parent_value",) + response = await client.list_models( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_models( - model_service.ListModelsRequest(), parent="parent_value", + model_service.ListModelsRequest(), + parent='parent_value', ) def test_list_models_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Set the response to a series of pages. 
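+        # Four fake pages (3, 0, 1 and 2 models) followed by RuntimeError;
+        # iterating the pager should flatten them into six Model results.
+        # Usage sketch (assuming the generated pagers module):
+        #
+        #     pager = client.list_models(request={})
+        #     results = [m for m in pager]   # lazily fetches each page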
call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_models(request={}) @@ -1105,96 +1136,147 @@ def test_list_models_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model.Model) for i in results) - + assert all(isinstance(i, model.Model) + for i in results) def test_list_models_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.list_models), "__call__") as call: + with mock.patch.object( + type(client.transport.list_models), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_models_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_models), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
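+        # Async variant: the stub is patched with mock.AsyncMock, and
+        # awaiting client.list_models(...) yields a ListModelsAsyncPager
+        # consumed with `async for`. Sketch (names as used in this test):
+        #
+        #     async_pager = await client.list_models(request={})
+        #     responses = [m async for m in async_pager]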
call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) for i in responses) - + assert all(isinstance(i, model.Model) + for i in responses) @pytest.mark.asyncio async def test_list_models_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock - ) as call: + type(client.transport.list_models), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', ), - model_service.ListModelsResponse(models=[], next_page_token="def",), model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), RuntimeError, ) pages = [] async for page_ in (await client.list_models(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_update_model( - transport: str = "grpc", request_type=model_service.UpdateModelRequest -): +def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1202,21 +1284,31 @@ def test_update_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. 
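+        # UpdateModel round-trips the mutable Model type, imported in these
+        # tests under the `gca_model` alias (presumably the generator's way
+        # of disambiguating it from the read-side `model` module used by
+        # get_model above).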
call.return_value = gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", + name='name_value', + + display_name='display_name_value', + + description='description_value', + + metadata_schema_uri='metadata_schema_uri_value', + + training_pipeline='training_pipeline_value', + + artifact_uri='artifact_uri_value', + + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + + supported_input_storage_formats=['supported_input_storage_formats_value'], + + supported_output_storage_formats=['supported_output_storage_formats_value'], + + etag='etag_value', + ) response = client.update_model(request) @@ -1228,33 +1320,28 @@ def test_update_model( assert args[0] == model_service.UpdateModelRequest() # Establish that the response is the type that we expect. + assert isinstance(response, gca_model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' def test_update_model_from_dict(): @@ -1262,38 +1349,33 @@ def test_update_model_from_dict(): @pytest.mark.asyncio -async def test_update_model_async(transport: str = "grpc_asyncio"): +async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = model_service.UpdateModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model( + name='name_value', + display_name='display_name_value', + description='description_value', + metadata_schema_uri='metadata_schema_uri_value', + training_pipeline='training_pipeline_value', + artifact_uri='artifact_uri_value', + supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], + supported_input_storage_formats=['supported_input_storage_formats_value'], + supported_output_storage_formats=['supported_output_storage_formats_value'], + etag='etag_value', + )) response = await client.update_model(request) @@ -1301,48 +1383,51 @@ async def test_update_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == model_service.UpdateModelRequest() # Establish that the response is the type that we expect. assert isinstance(response, gca_model.Model) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.description == "description_value" + assert response.description == 'description_value' - assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.metadata_schema_uri == 'metadata_schema_uri_value' - assert response.training_pipeline == "training_pipeline_value" + assert response.training_pipeline == 'training_pipeline_value' - assert response.artifact_uri == "artifact_uri_value" + assert response.artifact_uri == 'artifact_uri_value' - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] + assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] + assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] + assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] - assert response.etag == "etag_value" + assert response.etag == 'etag_value' + + +@pytest.mark.asyncio +async def test_update_model_async_from_dict(): + await test_update_model_async(request_type=dict) def test_update_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is 
part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" + request.model.name = 'model.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: call.return_value = gca_model.Model() client.update_model(request) @@ -1354,20 +1439,27 @@ def test_update_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'model.name=model.name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" + request.model.name = 'model.name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) await client.update_model(request) @@ -1379,22 +1471,29 @@ async def test_update_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'model.name=model.name/value', + ) in kw['metadata'] def test_update_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
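+        # update_mask is a google.protobuf FieldMask: only the listed paths
+        # are written on update. A hypothetical real-world call (the
+        # 'display_name' path is illustrative, not from this test):
+        #
+        #     client.update_model(
+        #         model=gca_model.Model(name='name_value',
+        #                               display_name='new name'),
+        #         update_mask=field_mask.FieldMask(paths=['display_name']),
+        #     )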
client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1402,30 +1501,36 @@ def test_update_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @pytest.mark.asyncio async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.update_model), "__call__") as call: + with mock.patch.object( + type(client.transport.update_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_model.Model() @@ -1433,8 +1538,8 @@ async def test_update_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1442,30 +1547,31 @@ async def test_update_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") + assert args[0].model == gca_model.Model(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.update_model( model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + model=gca_model.Model(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) -def test_delete_model( - transport: str = "grpc", request_type=model_service.DeleteModelRequest -): +def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1473,9 +1579,11 @@ def test_delete_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_model(request) @@ -1494,20 +1602,23 @@ def test_delete_model_from_dict(): @pytest.mark.asyncio -async def test_delete_model_async(transport: str = "grpc_asyncio"): +async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = model_service.DeleteModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_model(request) @@ -1516,23 +1627,32 @@ async def test_delete_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == model_service.DeleteModelRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_model_async_from_dict(): + await test_delete_model_async(request_type=dict) + + def test_delete_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.delete_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_model(request) @@ -1543,23 +1663,28 @@ def test_delete_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.DeleteModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_model(request) @@ -1570,81 +1695,101 @@ async def test_delete_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_model(name="name_value",) + client.delete_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.delete_model( - model_service.DeleteModelRequest(), name="name_value", + model_service.DeleteModelRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + with mock.patch.object( + type(client.transport.delete_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_model(name="name_value",) + response = await client.delete_model( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_model( - model_service.DeleteModelRequest(), name="name_value", + model_service.DeleteModelRequest(), + name='name_value', ) -def test_export_model( - transport: str = "grpc", request_type=model_service.ExportModelRequest -): +def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1652,9 +1797,11 @@ def test_export_model( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. 
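+        # ExportModel is long-running: the stub returns a raw
+        # operations_pb2.Operation, and the client wraps it in an operation
+        # future (these tests elsewhere assert
+        # isinstance(response, future.Future), with `future` presumably
+        # being google.api_core's future module per the file's imports).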
- call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.export_model(request) @@ -1673,20 +1820,23 @@ def test_export_model_from_dict(): @pytest.mark.asyncio -async def test_export_model_async(transport: str = "grpc_asyncio"): +async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=model_service.ExportModelRequest): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = model_service.ExportModelRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.export_model(request) @@ -1695,23 +1845,32 @@ async def test_export_model_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == model_service.ExportModelRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_export_model_async_from_dict(): + await test_export_model_async(request_type=dict) + + def test_export_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.export_model(request) @@ -1722,23 +1881,28 @@ def test_export_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ExportModelRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client.transport.export_model), "__call__") as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.export_model(request) @@ -1749,24 +1913,29 @@ async def test_export_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_export_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) # Establish that the underlying call was made with the expected @@ -1774,47 +1943,47 @@ def test_export_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') def test_export_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_model( model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) @pytest.mark.asyncio async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client.transport.export_model), "__call__") as call: + with mock.patch.object( + type(client.transport.export_model), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) # Establish that the underlying call was made with the expected @@ -1822,34 +1991,31 @@ async def test_export_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value') @pytest.mark.asyncio async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_model( model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), + name='name_value', + output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'), ) -def test_get_model_evaluation( - transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest -): +def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1858,13 +2024,16 @@ def test_get_model_evaluation( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation), "__call__" - ) as call: + type(client.transport.get_model_evaluation), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], + name='name_value', + + metrics_schema_uri='metrics_schema_uri_value', + + slice_dimensions=['slice_dimensions_value'], + ) response = client.get_model_evaluation(request) @@ -1876,13 +2045,14 @@ def test_get_model_evaluation( assert args[0] == model_service.GetModelEvaluationRequest() # Establish that the response is the type that we expect. 
+    assert isinstance(response, model_evaluation.ModelEvaluation)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'

-    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+    assert response.metrics_schema_uri == 'metrics_schema_uri_value'

-    assert response.slice_dimensions == ["slice_dimensions_value"]
+    assert response.slice_dimensions == ['slice_dimensions_value']


 def test_get_model_evaluation_from_dict():
@@ -1890,27 +2060,26 @@

 @pytest.mark.asyncio
-async def test_get_model_evaluation_async(transport: str = "grpc_asyncio"):
+async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationRequest):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.GetModelEvaluationRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_evaluation.ModelEvaluation(
-                name="name_value",
-                metrics_schema_uri="metrics_schema_uri_value",
-                slice_dimensions=["slice_dimensions_value"],
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation(
+            name='name_value',
+            metrics_schema_uri='metrics_schema_uri_value',
+            slice_dimensions=['slice_dimensions_value'],
+        ))

         response = await client.get_model_evaluation(request)

@@ -1918,30 +2087,37 @@ async def test_get_model_evaluation_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == model_service.GetModelEvaluationRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, model_evaluation.ModelEvaluation)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'
+
+    assert response.metrics_schema_uri == 'metrics_schema_uri_value'

-    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+    assert response.slice_dimensions == ['slice_dimensions_value']

-    assert response.slice_dimensions == ["slice_dimensions_value"]
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_async_from_dict():
+    await test_get_model_evaluation_async(request_type=dict)


 def test_get_model_evaluation_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.GetModelEvaluationRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation),
+            '__call__') as call:
         call.return_value = model_evaluation.ModelEvaluation()

         client.get_model_evaluation(request)

@@ -1953,25 +2129,28 @@ def test_get_model_evaluation_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_get_model_evaluation_field_headers_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.GetModelEvaluationRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_evaluation.ModelEvaluation()
-        )
+            type(client.transport.get_model_evaluation),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation())

         await client.get_model_evaluation(request)

@@ -1982,85 +2161,99 @@ async def test_get_model_evaluation_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_get_model_evaluation_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation.ModelEvaluation()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.get_model_evaluation(name="name_value",)
+        client.get_model_evaluation(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_get_model_evaluation_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.get_model_evaluation(
-            model_service.GetModelEvaluationRequest(), name="name_value",
+            model_service.GetModelEvaluationRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_get_model_evaluation_flattened_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation.ModelEvaluation()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_evaluation.ModelEvaluation()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.get_model_evaluation(name="name_value",)
+        response = await client.get_model_evaluation(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_get_model_evaluation_flattened_error_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.get_model_evaluation(
-            model_service.GetModelEvaluationRequest(), name="name_value",
+            model_service.GetModelEvaluationRequest(),
+            name='name_value',
         )


-def test_list_model_evaluations(
-    transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest
-):
+def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2069,11 +2262,12 @@ def test_list_model_evaluations(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelEvaluationsResponse(
-            next_page_token="next_page_token_value",
+            next_page_token='next_page_token_value',
+
         )

         response = client.list_model_evaluations(request)

@@ -2085,9 +2279,10 @@ def test_list_model_evaluations(
         assert args[0] == model_service.ListModelEvaluationsRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelEvaluationsPager)

-    assert response.next_page_token == "next_page_token_value"
+    assert response.next_page_token == 'next_page_token_value'


 def test_list_model_evaluations_from_dict():
@@ -2095,25 +2290,24 @@

 @pytest.mark.asyncio
-async def test_list_model_evaluations_async(transport: str = "grpc_asyncio"):
+async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationsRequest):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.ListModelEvaluationsRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_service.ListModelEvaluationsResponse(
-                next_page_token="next_page_token_value",
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse(
+            next_page_token='next_page_token_value',
+        ))

         response = await client.list_model_evaluations(request)

@@ -2121,26 +2315,33 @@ async def test_list_model_evaluations_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == model_service.ListModelEvaluationsRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListModelEvaluationsAsyncPager)

-    assert response.next_page_token == "next_page_token_value"
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_from_dict():
+    await test_list_model_evaluations_async(request_type=dict)


 def test_list_model_evaluations_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ListModelEvaluationsRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         call.return_value = model_service.ListModelEvaluationsResponse()

         client.list_model_evaluations(request)

@@ -2152,25 +2353,28 @@ def test_list_model_evaluations_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_list_model_evaluations_field_headers_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ListModelEvaluationsRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_service.ListModelEvaluationsResponse()
-        )
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse())

         await client.list_model_evaluations(request)

@@ -2181,87 +2385,104 @@ async def test_list_model_evaluations_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 def test_list_model_evaluations_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelEvaluationsResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_model_evaluations(parent="parent_value",)
+        client.list_model_evaluations(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 def test_list_model_evaluations_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_model_evaluations(
-            model_service.ListModelEvaluationsRequest(), parent="parent_value",
+            model_service.ListModelEvaluationsRequest(),
+            parent='parent_value',
         )


 @pytest.mark.asyncio
 async def test_list_model_evaluations_flattened_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelEvaluationsResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_service.ListModelEvaluationsResponse()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.list_model_evaluations(parent="parent_value",)
+        response = await client.list_model_evaluations(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 @pytest.mark.asyncio
 async def test_list_model_evaluations_flattened_error_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.list_model_evaluations(
-            model_service.ListModelEvaluationsRequest(), parent="parent_value",
+            model_service.ListModelEvaluationsRequest(),
+            parent='parent_value',
        )


 def test_list_model_evaluations_pager():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationsResponse(
@@ -2270,14 +2491,17 @@ def test_list_model_evaluations_pager():
                     model_evaluation.ModelEvaluation(),
                     model_evaluation.ModelEvaluation(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[], next_page_token="def",
+                model_evaluations=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[model_evaluation.ModelEvaluation(),],
-                next_page_token="ghi",
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationsResponse(
                 model_evaluations=[
@@ -2290,7 +2514,9 @@ def test_list_model_evaluations_pager():
         metadata = ()
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
         )
         pager = client.list_model_evaluations(request={})

@@ -2298,16 +2524,18 @@ def test_list_model_evaluations_pager():
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results)
-
+        assert all(isinstance(i, model_evaluation.ModelEvaluation)
+                   for i in results)

 def test_list_model_evaluations_pages():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationsResponse(
@@ -2316,14 +2544,17 @@ def test_list_model_evaluations_pages():
                     model_evaluation.ModelEvaluation(),
                     model_evaluation.ModelEvaluation(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[], next_page_token="def",
+                model_evaluations=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[model_evaluation.ModelEvaluation(),],
-                next_page_token="ghi",
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationsResponse(
                 model_evaluations=[
@@ -2334,20 +2565,19 @@ def test_list_model_evaluations_pages():
             RuntimeError,
         )
         pages = list(client.list_model_evaluations(request={}).pages)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token
-

 @pytest.mark.asyncio
 async def test_list_model_evaluations_async_pager():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationsResponse(
@@ -2356,14 +2586,17 @@ async def test_list_model_evaluations_async_pager():
                     model_evaluation.ModelEvaluation(),
                     model_evaluation.ModelEvaluation(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[], next_page_token="def",
+                model_evaluations=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[model_evaluation.ModelEvaluation(),],
-                next_page_token="ghi",
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationsResponse(
                 model_evaluations=[
@@ -2374,25 +2607,25 @@ async def test_list_model_evaluations_async_pager():
             RuntimeError,
         )
         async_pager = await client.list_model_evaluations(request={},)
-        assert async_pager.next_page_token == "abc"
+        assert async_pager.next_page_token == 'abc'
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses)
-
+        assert all(isinstance(i, model_evaluation.ModelEvaluation)
+                   for i in responses)

 @pytest.mark.asyncio
 async def test_list_model_evaluations_async_pages():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluations),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_model_evaluations),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationsResponse(
@@ -2401,14 +2634,17 @@ async def test_list_model_evaluations_async_pages():
                     model_evaluation.ModelEvaluation(),
                     model_evaluation.ModelEvaluation(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[], next_page_token="def",
+                model_evaluations=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationsResponse(
-                model_evaluations=[model_evaluation.ModelEvaluation(),],
-                next_page_token="ghi",
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationsResponse(
                 model_evaluations=[
@@ -2421,15 +2657,14 @@ async def test_list_model_evaluations_async_pages():
         pages = []
         async for page_ in (await client.list_model_evaluations(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
            assert page_.raw_page.next_page_token == token


-def test_get_model_evaluation_slice(
-    transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest
-):
+def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2438,11 +2673,14 @@ def test_get_model_evaluation_slice(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation_slice.ModelEvaluationSlice(
-            name="name_value", metrics_schema_uri="metrics_schema_uri_value",
+            name='name_value',
+
+            metrics_schema_uri='metrics_schema_uri_value',
+
         )

         response = client.get_model_evaluation_slice(request)

@@ -2454,11 +2692,12 @@ def test_get_model_evaluation_slice(
         assert args[0] == model_service.GetModelEvaluationSliceRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'

-    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+    assert response.metrics_schema_uri == 'metrics_schema_uri_value'


 def test_get_model_evaluation_slice_from_dict():
@@ -2466,25 +2705,25 @@

 @pytest.mark.asyncio
-async def test_get_model_evaluation_slice_async(transport: str = "grpc_asyncio"):
+async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationSliceRequest):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.GetModelEvaluationSliceRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_evaluation_slice.ModelEvaluationSlice(
-                name="name_value", metrics_schema_uri="metrics_schema_uri_value",
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice(
+            name='name_value',
+            metrics_schema_uri='metrics_schema_uri_value',
+        ))

         response = await client.get_model_evaluation_slice(request)

@@ -2492,28 +2731,35 @@ async def test_get_model_evaluation_slice_async(transport: str = "grpc_asyncio")
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == model_service.GetModelEvaluationSliceRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice)

-    assert response.name == "name_value"
+    assert response.name == 'name_value'
+
+    assert response.metrics_schema_uri == 'metrics_schema_uri_value'
+

-    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_async_from_dict():
+    await test_get_model_evaluation_slice_async(request_type=dict)


 def test_get_model_evaluation_slice_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.GetModelEvaluationSliceRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
         call.return_value = model_evaluation_slice.ModelEvaluationSlice()

         client.get_model_evaluation_slice(request)

@@ -2525,25 +2771,28 @@ def test_get_model_evaluation_slice_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_get_model_evaluation_slice_field_headers_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.GetModelEvaluationSliceRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_evaluation_slice.ModelEvaluationSlice()
-        )
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice())

         await client.get_model_evaluation_slice(request)

@@ -2554,85 +2803,99 @@ async def test_get_model_evaluation_slice_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_get_model_evaluation_slice_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation_slice.ModelEvaluationSlice()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.get_model_evaluation_slice(name="name_value",)
+        client.get_model_evaluation_slice(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_get_model_evaluation_slice_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.get_model_evaluation_slice(
-            model_service.GetModelEvaluationSliceRequest(), name="name_value",
+            model_service.GetModelEvaluationSliceRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_get_model_evaluation_slice_flattened_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_model_evaluation_slice), "__call__"
-    ) as call:
+            type(client.transport.get_model_evaluation_slice),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation_slice.ModelEvaluationSlice()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_evaluation_slice.ModelEvaluationSlice()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.get_model_evaluation_slice(name="name_value",)
+        response = await client.get_model_evaluation_slice(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_get_model_evaluation_slice_flattened_error_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.get_model_evaluation_slice(
-            model_service.GetModelEvaluationSliceRequest(), name="name_value",
+            model_service.GetModelEvaluationSliceRequest(),
+            name='name_value',
         )


-def test_list_model_evaluation_slices(
-    transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest
-):
+def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2641,11 +2904,12 @@ def test_list_model_evaluation_slices(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelEvaluationSlicesResponse(
-            next_page_token="next_page_token_value",
+            next_page_token='next_page_token_value',
+
         )

         response = client.list_model_evaluation_slices(request)

@@ -2657,9 +2921,10 @@ def test_list_model_evaluation_slices(
         assert args[0] == model_service.ListModelEvaluationSlicesRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelEvaluationSlicesPager)

-    assert response.next_page_token == "next_page_token_value"
+    assert response.next_page_token == 'next_page_token_value'


 def test_list_model_evaluation_slices_from_dict():
@@ -2667,25 +2932,24 @@

 @pytest.mark.asyncio
-async def test_list_model_evaluation_slices_async(transport: str = "grpc_asyncio"):
+async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationSlicesRequest):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = model_service.ListModelEvaluationSlicesRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_service.ListModelEvaluationSlicesResponse(
-                next_page_token="next_page_token_value",
-            )
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse(
+            next_page_token='next_page_token_value',
+        ))

         response = await client.list_model_evaluation_slices(request)

@@ -2693,26 +2957,33 @@ async def test_list_model_evaluation_slices_async(transport: str = "grpc_asyncio
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == model_service.ListModelEvaluationSlicesRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager)

-    assert response.next_page_token == "next_page_token_value"
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_from_dict():
+    await test_list_model_evaluation_slices_async(request_type=dict)


 def test_list_model_evaluation_slices_field_headers():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ListModelEvaluationSlicesRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         call.return_value = model_service.ListModelEvaluationSlicesResponse()

         client.list_model_evaluation_slices(request)

@@ -2724,25 +2995,28 @@ def test_list_model_evaluation_slices_field_headers():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_field_headers_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ListModelEvaluationSlicesRequest()
-    request.parent = "parent/value"
+    request.parent = 'parent/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_service.ListModelEvaluationSlicesResponse()
-        )
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse())

         await client.list_model_evaluation_slices(request)

@@ -2753,87 +3027,104 @@ async def test_list_model_evaluation_slices_field_headers_async():
     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 def test_list_model_evaluation_slices_flattened():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelEvaluationSlicesResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_model_evaluation_slices(parent="parent_value",)
+        client.list_model_evaluation_slices(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 def test_list_model_evaluation_slices_flattened_error():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_model_evaluation_slices(
-            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
+            model_service.ListModelEvaluationSlicesRequest(),
+            parent='parent_value',
         )


 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_flattened_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelEvaluationSlicesResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            model_service.ListModelEvaluationSlicesResponse()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.list_model_evaluation_slices(parent="parent_value",)
+        response = await client.list_model_evaluation_slices(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_flattened_error_async():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.list_model_evaluation_slices(
-            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
+            model_service.ListModelEvaluationSlicesRequest(),
+            parent='parent_value',
         )


 def test_list_model_evaluation_slices_pager():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationSlicesResponse(
@@ -2842,16 +3133,17 @@ def test_list_model_evaluation_slices_pager():
                     model_evaluation_slice.ModelEvaluationSlice(),
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[], next_page_token="def",
+                model_evaluation_slices=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="ghi",
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
@@ -2864,7 +3156,9 @@ def test_list_model_evaluation_slices_pager():
         metadata = ()
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
         )
         pager = client.list_model_evaluation_slices(request={})

@@ -2872,18 +3166,18 @@ def test_list_model_evaluation_slices_pager():
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(
-            isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results
-        )
-
+        assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+                   for i in results)

 def test_list_model_evaluation_slices_pages():
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices), "__call__"
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationSlicesResponse(
@@ -2892,16 +3186,17 @@ def test_list_model_evaluation_slices_pages():
                     model_evaluation_slice.ModelEvaluationSlice(),
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[], next_page_token="def",
+                model_evaluation_slices=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="ghi",
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
@@ -2912,20 +3207,19 @@ def test_list_model_evaluation_slices_pages():
             RuntimeError,
         )
         pages = list(client.list_model_evaluation_slices(request={}).pages)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token
-

 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_async_pager():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationSlicesResponse(
@@ -2934,16 +3228,17 @@ async def test_list_model_evaluation_slices_async_pager():
                     model_evaluation_slice.ModelEvaluationSlice(),
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[], next_page_token="def",
+                model_evaluation_slices=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="ghi",
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
@@ -2954,28 +3249,25 @@ async def test_list_model_evaluation_slices_async_pager():
             RuntimeError,
         )
         async_pager = await client.list_model_evaluation_slices(request={},)
-        assert async_pager.next_page_token == "abc"
+        assert async_pager.next_page_token == 'abc'
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(
-            isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
-            for i in responses
-        )
-
+        assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+                   for i in responses)

 @pytest.mark.asyncio
 async def test_list_model_evaluation_slices_async_pages():
-    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_model_evaluation_slices),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_model_evaluation_slices),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelEvaluationSlicesResponse(
@@ -2984,16 +3276,17 @@ async def test_list_model_evaluation_slices_async_pages():
                     model_evaluation_slice.ModelEvaluationSlice(),
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             model_service.ListModelEvaluationSlicesResponse(
-                model_evaluation_slices=[], next_page_token="def",
+                model_evaluation_slices=[],
+                next_page_token='def',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
                     model_evaluation_slice.ModelEvaluationSlice(),
                 ],
-                next_page_token="ghi",
+                next_page_token='ghi',
             ),
             model_service.ListModelEvaluationSlicesResponse(
                 model_evaluation_slices=[
@@ -3004,11 +3297,9 @@ async def test_list_model_evaluation_slices_async_pages():
             RuntimeError,
         )
         pages = []
-        async for page_ in (
-            await client.list_model_evaluation_slices(request={})
-        ).pages:
+        async for page_ in (await client.list_model_evaluation_slices(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token

@@ -3019,7 +3310,8 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = ModelServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
         )

     # It is an error to provide a credentials file and a transport instance.
@@ -3038,7 +3330,8 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = ModelServiceClient(
-            client_options={"scopes": ["1", "2"]}, transport=transport,
+            client_options={"scopes": ["1", "2"]},
+            transport=transport,
         )

@@ -3066,13 +3359,13 @@ def test_transport_get_channel():
     assert channel


-@pytest.mark.parametrize(
-    "transport_class",
-    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
-)
+@pytest.mark.parametrize("transport_class", [
+    transports.ModelServiceGrpcTransport,
+    transports.ModelServiceGrpcAsyncIOTransport
+])
 def test_transport_adc(transport_class):
     # Test default credentials are used if not provided.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport_class()
         adc.assert_called_once()
@@ -3080,8 +3373,13 @@ def test_transport_adc(transport_class):

 def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client.transport, transports.ModelServiceGrpcTransport,)
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    assert isinstance(
+        client.transport,
+        transports.ModelServiceGrpcTransport,
+    )


 def test_model_service_base_transport_error():
@@ -3089,15 +3387,13 @@ def test_model_service_base_transport_error():
     with pytest.raises(exceptions.DuplicateCredentialArgs):
         transport = transports.ModelServiceTransport(
             credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json",
+            credentials_file="credentials.json"
         )


 def test_model_service_base_transport():
     # Instantiate the base transport.
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__"
-    ) as Transport:
+    with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__') as Transport:
         Transport.return_value = None
         transport = transports.ModelServiceTransport(
             credentials=credentials.AnonymousCredentials(),
@@ -3106,17 +3402,17 @@ def test_model_service_base_transport():
     # Every method on the transport should just blindly
     # raise NotImplementedError.
     methods = (
-        "upload_model",
-        "get_model",
-        "list_models",
-        "update_model",
-        "delete_model",
-        "export_model",
-        "get_model_evaluation",
-        "list_model_evaluations",
-        "get_model_evaluation_slice",
-        "list_model_evaluation_slices",
-    )
+        'upload_model',
+        'get_model',
+        'list_models',
+        'update_model',
+        'delete_model',
+        'export_model',
+        'get_model_evaluation',
+        'list_model_evaluations',
+        'get_model_evaluation_slice',
+        'list_model_evaluation_slices',
+        )
     for method in methods:
         with pytest.raises(NotImplementedError):
             getattr(transport, method)(request=object())
@@ -3129,28 +3425,23 @@ def test_model_service_base_transport():

 def test_model_service_base_transport_with_credentials_file():
     # Instantiate the base transport with a credentials file
-    with mock.patch.object(
-        auth, "load_credentials_from_file"
-    ) as load_creds, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages"
-    ) as Transport:
+    with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport:
         Transport.return_value = None
         load_creds.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.ModelServiceTransport(
-            credentials_file="credentials.json", quota_project_id="octopus",
+            credentials_file="credentials.json",
+            quota_project_id="octopus",
         )
-        load_creds.assert_called_once_with(
-            "credentials.json",
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+        load_creds.assert_called_once_with("credentials.json", scopes=(
+            'https://www.googleapis.com/auth/cloud-platform',
+        ),
             quota_project_id="octopus",
         )


 def test_model_service_base_transport_with_adc():
     # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, "default") as adc, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages"
-    ) as Transport:
+    with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport:
         Transport.return_value = None
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.ModelServiceTransport()
@@ -3159,11 +3450,11 @@ def test_model_service_base_transport_with_adc():

 def test_model_service_auth_adc():
     # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         ModelServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+        adc.assert_called_once_with(scopes=(
+            'https://www.googleapis.com/auth/cloud-platform',),
             quota_project_id=None,
         )

@@ -3171,70 +3462,62 @@ def test_model_service_auth_adc():
 def test_model_service_transport_auth_adc():
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.ModelServiceGrpcTransport(
-            host="squid.clam.whelk", quota_project_id="octopus"
-        )
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+        transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
+        adc.assert_called_once_with(scopes=(
+            'https://www.googleapis.com/auth/cloud-platform',),
             quota_project_id="octopus",
         )

-
 def test_model_service_host_no_port():
     client = ModelServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
+        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'),
     )
-    assert client.transport._host == "aiplatform.googleapis.com:443"
+    assert client.transport._host == 'aiplatform.googleapis.com:443'


 def test_model_service_host_with_port():
     client = ModelServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
+        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'),
     )
-    assert client.transport._host == "aiplatform.googleapis.com:8000"
+    assert client.transport._host == 'aiplatform.googleapis.com:8000'


 def test_model_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
+    channel = grpc.insecure_channel('http://localhost/')

     # Check that channel is used if provided.
     transport = transports.ModelServiceGrpcTransport(
-        host="squid.clam.whelk", channel=channel,
+        host="squid.clam.whelk",
+        channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials == None


 def test_model_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
+    channel = aio.insecure_channel('http://localhost/')

     # Check that channel is used if provided.
transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3243,7 +3526,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3259,27 +3542,27 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_ "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3296,7 +3579,9 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3305,12 +3590,16 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class): def test_model_service_grpc_lro_client(): client = ModelServiceClient( - 
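test_model_service_transport_channel_mtls_with_adc relies on mock.patch.multiple to neutralize google.auth.transport.grpc.SslCredentials. The same trick in isolation, run against a local stand-in class so the sketch executes without probing for real certificates:

    from unittest import mock

    class SslCredentials:  # local stand-in for google.auth.transport.grpc.SslCredentials
        def __init__(self):
            raise RuntimeError("would probe for real mTLS certificates")

        @property
        def ssl_credentials(self):
            raise RuntimeError("unreachable unless patched")

    mock_ssl_cred = mock.Mock()
    with mock.patch.multiple(
        SslCredentials,
        __init__=mock.Mock(return_value=None),  # skip real cert discovery
        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
    ):
        # Inside the block, construction is inert and the property is canned.
        assert SslCredentials().ssl_credentials is mock_ssl_cred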
credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3318,34 +3607,36 @@ def test_model_service_grpc_lro_client(): def test_model_service_grpc_lro_async_client(): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = ModelServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = ModelServiceClient.endpoint_path(**expected) @@ -3353,24 +3644,22 @@ def test_parse_endpoint_path(): actual = ModelServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = ModelServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = ModelServiceClient.model_path(**expected) @@ -3378,28 +3667,24 @@ def test_parse_model_path(): actual = ModelServiceClient.parse_model_path(path) assert expected == actual - def test_model_evaluation_path(): project = "squid" location = "clam" model = "whelk" evaluation = "octopus" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( - project=project, location=location, model=model, evaluation=evaluation, - ) - actual = ModelServiceClient.model_evaluation_path( - project, location, model, evaluation - ) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, ) + actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation) 
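The *_path / parse_*_path tests above all exercise one pattern that recurs across the generated clients: compose a resource name with str.format, recover the pieces with a named-group regex. A reduced sketch of that pattern (local functions for illustration, not the classmethods under test):

    import re

    def model_path(project, location, model):
        return "projects/{project}/locations/{location}/models/{model}".format(
            project=project, location=location, model=model)

    def parse_model_path(path):
        m = re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
            path)
        return m.groupdict() if m else {}

    assert parse_model_path(model_path("cuttlefish", "mussel", "winkle")) == {
        "project": "cuttlefish", "location": "mussel", "model": "winkle"}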
assert expected == actual def test_parse_model_evaluation_path(): expected = { - "project": "oyster", - "location": "nudibranch", - "model": "cuttlefish", - "evaluation": "mussel", + "project": "oyster", + "location": "nudibranch", + "model": "cuttlefish", + "evaluation": "mussel", + } path = ModelServiceClient.model_evaluation_path(**expected) @@ -3407,7 +3692,6 @@ def test_parse_model_evaluation_path(): actual = ModelServiceClient.parse_model_evaluation_path(path) assert expected == actual - def test_model_evaluation_slice_path(): project = "winkle" location = "nautilus" @@ -3415,26 +3699,19 @@ def test_model_evaluation_slice_path(): evaluation = "abalone" slice = "squid" - expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( - project=project, - location=location, - model=model, - evaluation=evaluation, - slice=slice, - ) - actual = ModelServiceClient.model_evaluation_slice_path( - project, location, model, evaluation, slice - ) + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, ) + actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice) assert expected == actual def test_parse_model_evaluation_slice_path(): expected = { - "project": "clam", - "location": "whelk", - "model": "octopus", - "evaluation": "oyster", - "slice": "nudibranch", + "project": "clam", + "location": "whelk", + "model": "octopus", + "evaluation": "oyster", + "slice": "nudibranch", + } path = ModelServiceClient.model_evaluation_slice_path(**expected) @@ -3442,26 +3719,22 @@ def test_parse_model_evaluation_slice_path(): actual = ModelServiceClient.parse_model_evaluation_slice_path(path) assert expected == actual - def test_training_pipeline_path(): project = "cuttlefish" location = "mussel" training_pipeline = "winkle" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) - actual = ModelServiceClient.training_pipeline_path( - project, location, training_pipeline - ) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline) assert expected == actual def test_parse_training_pipeline_path(): expected = { - "project": "nautilus", - "location": "scallop", - "training_pipeline": "abalone", + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", + } path = ModelServiceClient.training_pipeline_path(**expected) @@ -3469,20 +3742,18 @@ def test_parse_training_pipeline_path(): actual = ModelServiceClient.parse_training_pipeline_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = ModelServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "clam", + } path = ModelServiceClient.common_billing_account_path(**expected) @@ -3490,18 +3761,18 @@ 
def test_parse_common_billing_account_path(): actual = ModelServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = ModelServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "octopus", + } path = ModelServiceClient.common_folder_path(**expected) @@ -3509,18 +3780,18 @@ def test_parse_common_folder_path(): actual = ModelServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = ModelServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "nudibranch", + } path = ModelServiceClient.common_organization_path(**expected) @@ -3528,18 +3799,18 @@ def test_parse_common_organization_path(): actual = ModelServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = ModelServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "mussel", + } path = ModelServiceClient.common_project_path(**expected) @@ -3547,22 +3818,20 @@ def test_parse_common_project_path(): actual = ModelServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = ModelServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "scallop", + "location": "abalone", + } path = ModelServiceClient.common_location_path(**expected) @@ -3574,19 +3843,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: transport_class = ModelServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) diff 
--git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 8a60b0a966..7ea561790e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -35,12 +35,8 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceClient, -) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceClient from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports from google.cloud.aiplatform_v1beta1.types import deployed_model_ref @@ -53,9 +49,7 @@ from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import pipeline_state from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) +from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline from google.longrunning import operations_pb2 from google.oauth2 import service_account from google.protobuf import any_pb2 as gp_any # type: ignore @@ -73,11 +67,7 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT def test__get_default_mtls_endpoint(): @@ -88,35 +78,17 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PipelineServiceClient._get_default_mtls_endpoint(None) is None - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) + assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient] -) +@pytest.mark.parametrize("client_class", [PipelineServiceClient, PipelineServiceAsyncClient]) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -124,7 +96,7 @@ def test_pipeline_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_pipeline_service_client_get_transport_class(): @@ -135,44 +107,29 @@ def test_pipeline_service_client_get_transport_class(): assert transport == transports.PipelineServiceGrpcTransport -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) -def test_pipeline_service_client_client_options( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) 
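test__get_default_mtls_endpoint pins down the host rewrite: a *.googleapis.com host gains a .mtls. label, sandbox hosts become .mtls.sandbox., and anything already mTLS or outside googleapis.com is returned untouched. A sketch consistent with those assertions (the regex details are an assumption for illustration, not lifted from the client):

    import re

    def get_default_mtls_endpoint(api_endpoint):
        if not api_endpoint:
            return api_endpoint
        m = re.match(
            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?",
            api_endpoint)
        name, mtls, sandbox, googledomain = m.groups()
        if mtls or not googledomain:
            return api_endpoint  # already mTLS, or not a googleapis host
        if sandbox:
            return api_endpoint.replace("sandbox.googleapis.com",
                                        "mtls.sandbox.googleapis.com")
        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")

    assert get_default_mtls_endpoint(None) is None
    assert get_default_mtls_endpoint("aiplatform.googleapis.com") == "aiplatform.mtls.googleapis.com"
    assert get_default_mtls_endpoint("aiplatform.mtls.googleapis.com") == "aiplatform.mtls.googleapis.com"
    assert get_default_mtls_endpoint("api.example.com") == "api.example.com"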
+@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) +def test_pipeline_service_client_client_options(client_class, transport_class, transport_name): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=credentials.AnonymousCredentials() + ) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -188,7 +145,7 @@ def test_pipeline_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -204,7 +161,7 @@ def test_pipeline_service_client_client_options( # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -224,15 +181,13 @@ def test_pipeline_service_client_client_options( client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
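The never/always/auto branches being checked here reduce to a small decision table: an explicit api_endpoint always wins, otherwise GOOGLE_API_USE_MTLS_ENDPOINT picks between the plain and mTLS defaults. A sketch of that table (select_endpoint is a hypothetical helper; the real logic lives inside the generated client's __init__, where an unsupported value raises MutualTLSChannelError rather than ValueError):

    import os

    def select_endpoint(api_endpoint, default, mtls_default, have_client_cert=False):
        if api_endpoint is not None:
            return api_endpoint  # explicit client option always wins
        use_mtls = os.environ.get("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
        if use_mtls not in ("never", "always", "auto"):
            raise ValueError(use_mtls)  # the real client raises MutualTLSChannelError
        if use_mtls == "never":
            return default
        if use_mtls == "always":
            return mtls_default
        return mtls_default if have_client_cert else default

    assert select_endpoint("squid.clam.whelk", "aiplatform.googleapis.com",
                           "aiplatform.mtls.googleapis.com") == "squid.clam.whelk"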
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -245,66 +200,26 @@ def test_pipeline_service_client_client_options( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "true", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "false", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") +]) +@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient)) +@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient)) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_pipeline_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): +def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -327,21 +242,11 @@ def test_pipeline_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -351,9 +256,7 @@ def test_pipeline_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -368,17 +271,10 @@ def test_pipeline_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
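All of these environment-driven branches are pinned with mock.patch.dict(os.environ, ...), which scopes an environment change to one with-block and restores the prior state on exit. The mechanism in isolation (assuming the variable is not already set in the surrounding environment):

    import os
    from unittest import mock

    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
        assert os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] == "true"

    # The patch is undone on exit, so the variable is gone again (given it
    # was unset to begin with).
    assert "GOOGLE_API_USE_CLIENT_CERTIFICATE" not in os.environ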
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -393,23 +289,16 @@ def test_pipeline_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -422,24 +311,16 @@ def test_pipeline_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -454,12 +335,10 @@ def test_pipeline_service_client_client_options_credentials_file( def test_pipeline_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = PipelineServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -472,11 +351,10 @@ def test_pipeline_service_client_client_options_from_dict(): ) -def test_create_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest -): +def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -485,14 +363,18 @@ def test_create_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) response = client.create_training_pipeline(request) @@ -504,13 +386,14 @@ def test_create_training_pipeline( assert args[0] == pipeline_service.CreateTrainingPipelineRequest() # Establish that the response is the type that we expect. 
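test_create_training_pipeline patches the transport's wrapped RPC method, fakes its return value, and then inspects call.mock_calls to see what was sent. The idiom reduced to a local stand-in so it runs without the generated client (Transport here is a toy class, not the library's):

    from unittest import mock

    class Transport:  # toy stand-in for the generated gRPC transport
        def create_training_pipeline(self, request):
            raise NotImplementedError  # the real stub would hit the network

    transport = Transport()
    with mock.patch.object(type(transport), "create_training_pipeline") as call:
        call.return_value = {"name": "name_value"}  # fake the RPC response
        response = transport.create_training_pipeline(
            request={"parent": "parent/value"})

    # The mock records the request, so the test can assert on what was sent.
    _, args, kwargs = call.mock_calls[0]
    assert kwargs["request"] == {"parent": "parent/value"}
    assert response == {"name": "name_value"}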
+ assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -520,28 +403,27 @@ def test_create_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_create_training_pipeline_async(transport: str = "grpc_asyncio"): +async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CreateTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.CreateTrainingPipelineRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) response = await client.create_training_pipeline(request) @@ -549,32 +431,39 @@ async def test_create_training_pipeline_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() # Establish that the response is the type that we expect. assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED +@pytest.mark.asyncio +async def test_create_training_pipeline_async_from_dict(): + await test_create_training_pipeline_async(request_type=dict) + + def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
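The async variant differs in one mechanical detail: the client awaits the stub, so the mocked call must hand back an awaitable, which is the role grpc_helpers_async.FakeUnaryUnaryCall plays in these tests. The shape of that requirement, with a plain asyncio Future standing in for FakeUnaryUnaryCall:

    import asyncio
    from unittest import mock

    async def main():
        stub = mock.Mock()
        fut = asyncio.get_running_loop().create_future()
        fut.set_result("response")  # the value FakeUnaryUnaryCall would wrap
        stub.return_value = fut
        assert await stub(request={}) == "response"
        stub.assert_called_once_with(request={})

    asyncio.run(main())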
request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: call.return_value = gca_training_pipeline.TrainingPipeline() client.create_training_pipeline(request) @@ -586,25 +475,28 @@ def test_create_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) + type(client.transport.create_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) await client.create_training_pipeline(request) @@ -615,24 +507,29 @@ async def test_create_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
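The field-header assertions verify that resource names from the request are mirrored into x-goog-request-params metadata; gapic_v1.routing_header.to_grpc_metadata builds that entry in the tests. A simplified mimic of it (the slash is deliberately left unencoded, matching the "parent=parent/value" expectation above):

    from urllib.parse import quote

    def to_grpc_metadata(params):
        # Simplified mimic of gapic_v1.routing_header.to_grpc_metadata.
        value = "&".join(f"{key}={quote(str(val), safe='/')}" for key, val in params)
        return ("x-goog-request-params", value)

    assert to_grpc_metadata([("parent", "parent/value")]) == (
        "x-goog-request-params", "parent=parent/value")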
client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -640,45 +537,45 @@ def test_create_training_pipeline_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) @pytest.mark.asyncio async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_training_pipeline), "__call__" - ) as call: + type(client.transport.create_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = gca_training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -686,32 +583,31 @@ async def test_create_training_pipeline_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value') @pytest.mark.asyncio async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. 
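Every *_flattened_error test asserts the same contract: a full request object and flattened keyword fields are mutually exclusive. The guard in miniature (build_request is a hypothetical helper for illustration, not the generated client method):

    def build_request(request=None, *, parent=None):
        if request is not None and parent is not None:
            raise ValueError("If the `request` argument is set, then none of "
                             "the individual field arguments should be set.")
        request = dict(request or {})
        if parent is not None:
            request["parent"] = parent
        return request

    assert build_request(parent="parent_value") == {"parent": "parent_value"}
    try:
        build_request({"parent": "p"}, parent="parent_value")
    except ValueError:
        pass  # both forms at once is rejected, as the tests above expect
    else:
        raise AssertionError("expected ValueError")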
with pytest.raises(ValueError): await client.create_training_pipeline( pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + parent='parent_value', + training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'), ) -def test_get_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest -): +def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -720,14 +616,18 @@ def test_get_training_pipeline( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", + name='name_value', + + display_name='display_name_value', + + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) response = client.get_training_pipeline(request) @@ -739,13 +639,14 @@ def test_get_training_pipeline( assert args[0] == pipeline_service.GetTrainingPipelineRequest() # Establish that the response is the type that we expect. + assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED @@ -755,28 +656,27 @@ def test_get_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_get_training_pipeline_async(transport: str = "grpc_asyncio"): +async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.GetTrainingPipelineRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.GetTrainingPipelineRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline( + name='name_value', + display_name='display_name_value', + training_task_definition='training_task_definition_value', + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + )) response = await client.get_training_pipeline(request) @@ -784,32 +684,39 @@ async def test_get_training_pipeline_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == pipeline_service.GetTrainingPipelineRequest() # Establish that the response is the type that we expect. assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' - assert response.training_task_definition == "training_task_definition_value" + assert response.training_task_definition == 'training_task_definition_value' assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED +@pytest.mark.asyncio +async def test_get_training_pipeline_async_from_dict(): + await test_get_training_pipeline_async(request_type=dict) + + def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: call.return_value = training_pipeline.TrainingPipeline() client.get_training_pipeline(request) @@ -821,25 +728,28 @@ def test_get_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) + type(client.transport.get_training_pipeline), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) await client.get_training_pipeline(request) @@ -850,85 +760,99 @@ async def test_get_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_training_pipeline(name="name_value",) + client.get_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', ) @pytest.mark.asyncio async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_training_pipeline), "__call__" - ) as call: + type(client.transport.get_training_pipeline), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = training_pipeline.TrainingPipeline() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_training_pipeline(name="name_value",) + response = await client.get_training_pipeline( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", + pipeline_service.GetTrainingPipelineRequest(), + name='name_value', ) -def test_list_training_pipelines( - transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest -): +def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -937,11 +861,12 @@ def test_list_training_pipelines( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_training_pipelines(request) @@ -953,9 +878,10 @@ def test_list_training_pipelines( assert args[0] == pipeline_service.ListTrainingPipelinesRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_training_pipelines_from_dict(): @@ -963,25 +889,24 @@ def test_list_training_pipelines_from_dict(): @pytest.mark.asyncio -async def test_list_training_pipelines_async(transport: str = "grpc_asyncio"): +async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.ListTrainingPipelinesRequest): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.ListTrainingPipelinesRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse( + next_page_token='next_page_token_value', + )) response = await client.list_training_pipelines(request) @@ -989,26 +914,33 @@ async def test_list_training_pipelines_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_training_pipelines_async_from_dict(): + await test_list_training_pipelines_async(request_type=dict) def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: + type(client.transport.list_training_pipelines), + '__call__') as call: call.return_value = pipeline_service.ListTrainingPipelinesResponse() client.list_training_pipelines(request) @@ -1020,25 +952,28 @@ def test_list_training_pipelines_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) + type(client.transport.list_training_pipelines), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) await client.list_training_pipelines(request) @@ -1049,87 +984,104 @@ async def test_list_training_pipelines_field_headers_async(): # Establish that the field header was sent. 
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'parent=parent/value',
+    ) in kw['metadata']


 def test_list_training_pipelines_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines), "__call__"
-    ) as call:
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = pipeline_service.ListTrainingPipelinesResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_training_pipelines(parent="parent_value",)
+        client.list_training_pipelines(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 def test_list_training_pipelines_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_training_pipelines(
-            pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
+            pipeline_service.ListTrainingPipelinesRequest(),
+            parent='parent_value',
         )


 @pytest.mark.asyncio
 async def test_list_training_pipelines_flattened_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines), "__call__"
-    ) as call:
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = pipeline_service.ListTrainingPipelinesResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            pipeline_service.ListTrainingPipelinesResponse()
-        )
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse())

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.list_training_pipelines(parent="parent_value",)
+        response = await client.list_training_pipelines(
+            parent='parent_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == "parent_value"
+        assert args[0].parent == 'parent_value'


 @pytest.mark.asyncio
 async def test_list_training_pipelines_flattened_error_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.list_training_pipelines(
-            pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
+            pipeline_service.ListTrainingPipelinesRequest(),
+            parent='parent_value',
         )


 def test_list_training_pipelines_pager():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines), "__call__"
-    ) as call:
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             pipeline_service.ListTrainingPipelinesResponse(
@@ -1138,14 +1090,17 @@ def test_list_training_pipelines_pager():
                     training_pipeline.TrainingPipeline(),
                     training_pipeline.TrainingPipeline(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
+                training_pipelines=[],
+                next_page_token='def',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
                 training_pipelines=[
@@ -1158,7 +1113,9 @@ def test_list_training_pipelines_pager():

         metadata = ()
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
         )
         pager = client.list_training_pipelines(request={})

@@ -1166,16 +1123,18 @@ def test_list_training_pipelines_pager():
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results)
-
+        assert all(isinstance(i, training_pipeline.TrainingPipeline)
+                   for i in results)

 def test_list_training_pipelines_pages():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines), "__call__"
-    ) as call:
+            type(client.transport.list_training_pipelines),
+            '__call__') as call:
         # Set the response to a series of pages.
         call.side_effect = (
             pipeline_service.ListTrainingPipelinesResponse(
@@ -1184,14 +1143,17 @@ def test_list_training_pipelines_pages():
                     training_pipeline.TrainingPipeline(),
                     training_pipeline.TrainingPipeline(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
+                training_pipelines=[],
+                next_page_token='def',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
                 training_pipelines=[
@@ -1202,20 +1164,19 @@ def test_list_training_pipelines_pages():
             RuntimeError,
         )
         pages = list(client.list_training_pipelines(request={}).pages)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token

-
 @pytest.mark.asyncio
 async def test_list_training_pipelines_async_pager():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_training_pipelines),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             pipeline_service.ListTrainingPipelinesResponse(
@@ -1224,14 +1185,17 @@ async def test_list_training_pipelines_async_pager():
                     training_pipeline.TrainingPipeline(),
                     training_pipeline.TrainingPipeline(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
+                training_pipelines=[],
+                next_page_token='def',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
                 training_pipelines=[
@@ -1242,25 +1206,25 @@ async def test_list_training_pipelines_async_pager():
             RuntimeError,
         )
         async_pager = await client.list_training_pipelines(request={},)
-        assert async_pager.next_page_token == "abc"
+        assert async_pager.next_page_token == 'abc'
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses)
-
+        assert all(isinstance(i, training_pipeline.TrainingPipeline)
+                   for i in responses)

 @pytest.mark.asyncio
 async def test_list_training_pipelines_async_pages():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials,
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
+            type(client.transport.list_training_pipelines),
+            '__call__', new_callable=mock.AsyncMock) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             pipeline_service.ListTrainingPipelinesResponse(
@@ -1269,14 +1233,17 @@ async def test_list_training_pipelines_async_pages():
                     training_pipeline.TrainingPipeline(),
                     training_pipeline.TrainingPipeline(),
                 ],
-                next_page_token="abc",
+                next_page_token='abc',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
+                training_pipelines=[],
+                next_page_token='def',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token='ghi',
             ),
             pipeline_service.ListTrainingPipelinesResponse(
                 training_pipelines=[
@@ -1289,15 +1256,14 @@ async def test_list_training_pipelines_async_pages():
         pages = []
         async for page_ in (await client.list_training_pipelines(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
             assert page_.raw_page.next_page_token == token


-def test_delete_training_pipeline(
-    transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest
-):
+def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest):
     client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1306,10 +1272,10 @@ def test_delete_training_pipeline(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
+        call.return_value = operations_pb2.Operation(name='operations/spam')

         response = client.delete_training_pipeline(request)

@@ -1328,22 +1294,23 @@ def test_delete_training_pipeline_from_dict():


 @pytest.mark.asyncio
-async def test_delete_training_pipeline_async(transport: str = "grpc_asyncio"):
+async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.DeleteTrainingPipelineRequest):
     client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.DeleteTrainingPipelineRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
+            operations_pb2.Operation(name='operations/spam')
         )

         response = await client.delete_training_pipeline(request)

@@ -1352,25 +1319,32 @@ async def test_delete_training_pipeline_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == pipeline_service.DeleteTrainingPipelineRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_delete_training_pipeline_async_from_dict():
+    await test_delete_training_pipeline_async(request_type=dict)
+
+
 def test_delete_training_pipeline_field_headers():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.DeleteTrainingPipelineRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_training_pipeline), "__call__"
-    ) as call:
-        call.return_value = operations_pb2.Operation(name="operations/op")
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
+        call.return_value = operations_pb2.Operation(name='operations/op')

         client.delete_training_pipeline(request)

@@ -1381,25 +1355,28 @@ def test_delete_training_pipeline_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_delete_training_pipeline_field_headers_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.DeleteTrainingPipelineRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_training_pipeline), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/op")
-        )
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))

         await client.delete_training_pipeline(request)

@@ -1410,85 +1387,101 @@ async def test_delete_training_pipeline_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_delete_training_pipeline_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
+        call.return_value = operations_pb2.Operation(name='operations/op')

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.delete_training_pipeline(name="name_value",)
+        client.delete_training_pipeline(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_delete_training_pipeline_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.delete_training_pipeline(
-            pipeline_service.DeleteTrainingPipelineRequest(), name="name_value",
+            pipeline_service.DeleteTrainingPipelineRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_delete_training_pipeline_flattened_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.delete_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.delete_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
+        call.return_value = operations_pb2.Operation(name='operations/op')

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
+            operations_pb2.Operation(name='operations/spam')
         )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.delete_training_pipeline(name="name_value",)
+        response = await client.delete_training_pipeline(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_delete_training_pipeline_flattened_error_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.delete_training_pipeline(
-            pipeline_service.DeleteTrainingPipelineRequest(), name="name_value",
+            pipeline_service.DeleteTrainingPipelineRequest(),
+            name='name_value',
         )


-def test_cancel_training_pipeline(
-    transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest
-):
+def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest):
     client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1497,8 +1490,8 @@ def test_cancel_training_pipeline(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

@@ -1519,19 +1512,20 @@ def test_cancel_training_pipeline_from_dict():


 @pytest.mark.asyncio
-async def test_cancel_training_pipeline_async(transport: str = "grpc_asyncio"):
+async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CancelTrainingPipelineRequest):
     client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
+        credentials=credentials.AnonymousCredentials(),
+        transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.CancelTrainingPipelineRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

@@ -1541,24 +1535,31 @@ async def test_cancel_training_pipeline_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == pipeline_service.CancelTrainingPipelineRequest()

     # Establish that the response is the type that we expect.
     assert response is None


+@pytest.mark.asyncio
+async def test_cancel_training_pipeline_async_from_dict():
+    await test_cancel_training_pipeline_async(request_type=dict)
+
+
 def test_cancel_training_pipeline_field_headers():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.CancelTrainingPipelineRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
         call.return_value = None

         client.cancel_training_pipeline(request)

@@ -1570,22 +1571,27 @@ def test_cancel_training_pipeline_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 @pytest.mark.asyncio
 async def test_cancel_training_pipeline_field_headers_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.CancelTrainingPipelineRequest()
-    request.name = "name/value"
+    request.name = 'name/value'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         await client.cancel_training_pipeline(request)

@@ -1597,75 +1603,92 @@ async def test_cancel_training_pipeline_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+    assert (
+        'x-goog-request-params',
+        'name=name/value',
+    ) in kw['metadata']


 def test_cancel_training_pipeline_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.cancel_training_pipeline(name="name_value",)
+        client.cancel_training_pipeline(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 def test_cancel_training_pipeline_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.cancel_training_pipeline(
-            pipeline_service.CancelTrainingPipelineRequest(), name="name_value",
+            pipeline_service.CancelTrainingPipelineRequest(),
+            name='name_value',
         )


 @pytest.mark.asyncio
 async def test_cancel_training_pipeline_flattened_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.cancel_training_pipeline), "__call__"
-    ) as call:
+            type(client.transport.cancel_training_pipeline),
+            '__call__') as call:
         # Designate an appropriate return value for the call.
         call.return_value = None

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.cancel_training_pipeline(name="name_value",)
+        response = await client.cancel_training_pipeline(
+            name='name_value',
+        )

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == "name_value"
+        assert args[0].name == 'name_value'


 @pytest.mark.asyncio
 async def test_cancel_training_pipeline_flattened_error_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.cancel_training_pipeline(
-            pipeline_service.CancelTrainingPipelineRequest(), name="name_value",
+            pipeline_service.CancelTrainingPipelineRequest(),
+            name='name_value',
         )

@@ -1676,7 +1699,8 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = PipelineServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
+            credentials=credentials.AnonymousCredentials(),
+            transport=transport,
         )

     # It is an error to provide a credentials file and a transport instance.
@@ -1695,7 +1719,8 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = PipelineServiceClient(
-            client_options={"scopes": ["1", "2"]}, transport=transport,
+            client_options={"scopes": ["1", "2"]},
+            transport=transport,
         )

@@ -1723,16 +1748,13 @@ def test_transport_get_channel():
     assert channel


-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PipelineServiceGrpcTransport,
-        transports.PipelineServiceGrpcAsyncIOTransport,
-    ],
-)
+@pytest.mark.parametrize("transport_class", [
+    transports.PipelineServiceGrpcTransport,
+    transports.PipelineServiceGrpcAsyncIOTransport
+])
 def test_transport_adc(transport_class):
     # Test default credentials are used if not provided.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport_class()
         adc.assert_called_once()

@@ -1740,8 +1762,13 @@ def test_transport_adc(transport_class):

 def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,)
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    assert isinstance(
+        client.transport,
+        transports.PipelineServiceGrpcTransport,
+    )


 def test_pipeline_service_base_transport_error():
     # Passing both a credentials object and credentials_file should raise an error
     with pytest.raises(exceptions.DuplicateCredentialArgs):
         transport = transports.PipelineServiceTransport(
             credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json",
+            credentials_file="credentials.json"
         )


 def test_pipeline_service_base_transport():
     # Instantiate the base transport.
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__"
-    ) as Transport:
+    with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport:
         Transport.return_value = None
         transport = transports.PipelineServiceTransport(
             credentials=credentials.AnonymousCredentials(),
@@ -1766,12 +1791,12 @@ def test_pipeline_service_base_transport():
     # Every method on the transport should just blindly
     # raise NotImplementedError.
     methods = (
-        "create_training_pipeline",
-        "get_training_pipeline",
-        "list_training_pipelines",
-        "delete_training_pipeline",
-        "cancel_training_pipeline",
-    )
+        'create_training_pipeline',
+        'get_training_pipeline',
+        'list_training_pipelines',
+        'delete_training_pipeline',
+        'cancel_training_pipeline',
+        )
     for method in methods:
         with pytest.raises(NotImplementedError):
             getattr(transport, method)(request=object())

@@ -1784,28 +1809,23 @@ def test_pipeline_service_base_transport():

 def test_pipeline_service_base_transport_with_credentials_file():
     # Instantiate the base transport with a credentials file
-    with mock.patch.object(
-        auth, "load_credentials_from_file"
-    ) as load_creds, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages"
-    ) as Transport:
+    with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport:
         Transport.return_value = None
         load_creds.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.PipelineServiceTransport(
-            credentials_file="credentials.json", quota_project_id="octopus",
+            credentials_file="credentials.json",
+            quota_project_id="octopus",
         )
-        load_creds.assert_called_once_with(
-            "credentials.json",
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+        load_creds.assert_called_once_with("credentials.json", scopes=(
+            'https://www.googleapis.com/auth/cloud-platform',
+        ),
             quota_project_id="octopus",
         )


 def test_pipeline_service_base_transport_with_adc():
     # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, "default") as adc, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages"
-    ) as Transport:
+    with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport:
         Transport.return_value = None
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.PipelineServiceTransport()

@@ -1814,11 +1834,11 @@ def test_pipeline_service_base_transport_with_adc():

 def test_pipeline_service_auth_adc():
     # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         PipelineServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+        adc.assert_called_once_with(scopes=(
+            'https://www.googleapis.com/auth/cloud-platform',),
             quota_project_id=None,
         )

@@ -1826,75 +1846,62 @@ def test_pipeline_service_auth_adc():
 def test_pipeline_service_transport_auth_adc():
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
+    with mock.patch.object(auth, 'default') as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.PipelineServiceGrpcTransport(
-            host="squid.clam.whelk", quota_project_id="octopus"
-        )
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+        transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
+        adc.assert_called_once_with(scopes=(
+            'https://www.googleapis.com/auth/cloud-platform',),
             quota_project_id="octopus",
         )

-
 def test_pipeline_service_host_no_port():
     client = PipelineServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
+        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'),
     )
-    assert client.transport._host == "aiplatform.googleapis.com:443"
+    assert client.transport._host == 'aiplatform.googleapis.com:443'


 def test_pipeline_service_host_with_port():
     client = PipelineServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
+        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'),
     )
-    assert client.transport._host == "aiplatform.googleapis.com:8000"
+    assert client.transport._host == 'aiplatform.googleapis.com:8000'


 def test_pipeline_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
+    channel = grpc.insecure_channel('http://localhost/')

     # Check that channel is used if provided.
     transport = transports.PipelineServiceGrpcTransport(
-        host="squid.clam.whelk", channel=channel,
+        host="squid.clam.whelk",
+        channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials == None


 def test_pipeline_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
+    channel = aio.insecure_channel('http://localhost/')

     # Check that channel is used if provided.
     transport = transports.PipelineServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk", channel=channel,
+        host="squid.clam.whelk",
+        channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials == None


-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PipelineServiceGrpcTransport,
-        transports.PipelineServiceGrpcAsyncIOTransport,
-    ],
-)
+@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport])
 def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
-    transport_class,
+    transport_class
 ):
-    with mock.patch(
-        "grpc.ssl_channel_credentials", autospec=True
-    ) as grpc_ssl_channel_cred:
-        with mock.patch.object(
-            transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred

@@ -1903,7 +1910,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
             cred = credentials.AnonymousCredentials()
             with pytest.warns(DeprecationWarning):
-                with mock.patch.object(auth, "default") as adc:
+                with mock.patch.object(auth, 'default') as adc:
                     adc.return_value = (cred, None)
                     transport = transport_class(
                         host="squid.clam.whelk",
@@ -1919,30 +1926,27 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
                 "mtls.squid.clam.whelk:443",
                 credentials=cred,
                 credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
             )
             assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred


-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PipelineServiceGrpcTransport,
-        transports.PipelineServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
+@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport])
+def test_pipeline_service_transport_channel_mtls_with_adc(
+    transport_class
+):
     mock_ssl_cred = mock.Mock()
     with mock.patch.multiple(
         "google.auth.transport.grpc.SslCredentials",
         __init__=mock.Mock(return_value=None),
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
-        with mock.patch.object(
-            transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
+        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
             mock_cred = mock.Mock()
@@ -1959,7 +1963,9 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
                 "mtls.squid.clam.whelk:443",
                 credentials=mock_cred,
                 credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                scopes=(
+                    'https://www.googleapis.com/auth/cloud-platform',
+                ),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
             )
@@ -1968,12 +1974,16 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):

 def test_pipeline_service_grpc_lro_client():
     client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1981,34 +1991,36 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } path = PipelineServiceClient.endpoint_path(**expected) @@ -2016,24 +2028,22 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual - def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) + expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } path = PipelineServiceClient.model_path(**expected) @@ -2041,26 +2051,22 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual - def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( - project=project, location=location, training_pipeline=training_pipeline, - ) - actual = PipelineServiceClient.training_pipeline_path( - project, location, training_pipeline - ) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) + actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) assert expected == actual 


 def test_parse_training_pipeline_path():
     expected = {
-        "project": "octopus",
-        "location": "oyster",
-        "training_pipeline": "nudibranch",
+    "project": "octopus",
+    "location": "oyster",
+    "training_pipeline": "nudibranch",
+
     }
     path = PipelineServiceClient.training_pipeline_path(**expected)

@@ -2068,20 +2074,18 @@ def test_parse_training_pipeline_path():
     actual = PipelineServiceClient.parse_training_pipeline_path(path)
     assert expected == actual

-
 def test_common_billing_account_path():
     billing_account = "cuttlefish"

-    expected = "billingAccounts/{billing_account}".format(
-        billing_account=billing_account,
-    )
+    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
     actual = PipelineServiceClient.common_billing_account_path(billing_account)
     assert expected == actual


 def test_parse_common_billing_account_path():
     expected = {
-        "billing_account": "mussel",
+    "billing_account": "mussel",
+
     }
     path = PipelineServiceClient.common_billing_account_path(**expected)

@@ -2089,18 +2093,18 @@ def test_parse_common_billing_account_path():
     actual = PipelineServiceClient.parse_common_billing_account_path(path)
     assert expected == actual

-
 def test_common_folder_path():
     folder = "winkle"

-    expected = "folders/{folder}".format(folder=folder,)
+    expected = "folders/{folder}".format(folder=folder, )
     actual = PipelineServiceClient.common_folder_path(folder)
     assert expected == actual


 def test_parse_common_folder_path():
     expected = {
-        "folder": "nautilus",
+    "folder": "nautilus",
+
     }
     path = PipelineServiceClient.common_folder_path(**expected)

@@ -2108,18 +2112,18 @@ def test_parse_common_folder_path():
     actual = PipelineServiceClient.parse_common_folder_path(path)
     assert expected == actual

-
 def test_common_organization_path():
     organization = "scallop"

-    expected = "organizations/{organization}".format(organization=organization,)
+    expected = "organizations/{organization}".format(organization=organization, )
     actual = PipelineServiceClient.common_organization_path(organization)
     assert expected == actual


 def test_parse_common_organization_path():
     expected = {
-        "organization": "abalone",
+    "organization": "abalone",
+
     }
     path = PipelineServiceClient.common_organization_path(**expected)

@@ -2127,18 +2131,18 @@ def test_parse_common_organization_path():
     actual = PipelineServiceClient.parse_common_organization_path(path)
     assert expected == actual

-
 def test_common_project_path():
     project = "squid"

-    expected = "projects/{project}".format(project=project,)
+    expected = "projects/{project}".format(project=project, )
     actual = PipelineServiceClient.common_project_path(project)
     assert expected == actual


 def test_parse_common_project_path():
     expected = {
-        "project": "clam",
+    "project": "clam",
+
     }
     path = PipelineServiceClient.common_project_path(**expected)

@@ -2146,22 +2150,20 @@ def test_parse_common_project_path():
     actual = PipelineServiceClient.parse_common_project_path(path)
     assert expected == actual

-
 def test_common_location_path():
     project = "whelk"
     location = "octopus"

-    expected = "projects/{project}/locations/{location}".format(
-        project=project, location=location,
-    )
+    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
     actual = PipelineServiceClient.common_location_path(project, location)
     assert expected == actual


 def test_parse_common_location_path():
     expected = {
-        "project": "oyster",
-        "location": "nudibranch",
+    "project": "oyster",
+    "location": "nudibranch",
+
     }
     path = PipelineServiceClient.common_location_path(**expected)
@@ -2173,19 +2175,17 @@ def test_parse_common_location_path():

 def test_client_withDEFAULT_CLIENT_INFO():
     client_info = gapic_v1.client_info.ClientInfo()

-    with mock.patch.object(
-        transports.PipelineServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
+    with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep:
         client = PipelineServiceClient(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
         )
         prep.assert_called_once_with(client_info)

-    with mock.patch.object(
-        transports.PipelineServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
+    with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep:
         transport_class = PipelineServiceClient.get_transport_class()
         transport = transport_class(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+            credentials=credentials.AnonymousCredentials(),
+            client_info=client_info,
         )
         prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
index bb0461f5ee..a9a2977768 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
@@ -35,12 +35,8 @@
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
-from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
-    SpecialistPoolServiceAsyncClient,
-)
-from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
-    SpecialistPoolServiceClient,
-)
+from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient
+from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceClient
 from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers
 from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports
 from google.cloud.aiplatform_v1beta1.types import operation as gca_operation
@@ -60,11 +56,7 @@ def client_cert_source_callback():

 # This method modifies the default endpoint so the client can produce a different
 # mtls endpoint for endpoint testing purposes.
 def modify_default_endpoint(client):
-    return (
-        "foo.googleapis.com"
-        if ("localhost" in client.DEFAULT_ENDPOINT)
-        else client.DEFAULT_ENDPOINT
-    )
+    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT


 def test__get_default_mtls_endpoint():
@@ -75,36 +67,17 @@ def test__get_default_mtls_endpoint():
     non_googleapi = "api.example.com"

     assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None
-    assert (
-        SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint)
-        == api_mtls_endpoint
-    )
-    assert (
-        SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint)
-        == api_mtls_endpoint
-    )
-    assert (
-        SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint)
-        == sandbox_mtls_endpoint
-    )
-    assert (
-        SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
-        == sandbox_mtls_endpoint
-    )
-    assert (
-        SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi)
-        == non_googleapi
-    )
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
+    assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi


-@pytest.mark.parametrize(
-    "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient]
-)
+@pytest.mark.parametrize("client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient])
 def test_specialist_pool_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
-    with mock.patch.object(
-        service_account.Credentials, "from_service_account_file"
-    ) as factory:
+    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
         assert client.transport._credentials == creds
@@ -112,7 +85,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class):
         client = client_class.from_service_account_json("dummy/file/path.json")
         assert client.transport._credentials == creds

-        assert client.transport._host == "aiplatform.googleapis.com:443"
+        assert client.transport._host == 'aiplatform.googleapis.com:443'


 def test_specialist_pool_service_client_get_transport_class():
@@ -123,48 +96,29 @@ def test_specialist_pool_service_client_get_transport_class():
     assert transport == transports.SpecialistPoolServiceGrpcTransport


-@pytest.mark.parametrize(
-    "client_class,transport_class,transport_name",
-    [
-        (
-            SpecialistPoolServiceClient,
-            transports.SpecialistPoolServiceGrpcTransport,
-            "grpc",
-        ),
-        (
-            SpecialistPoolServiceAsyncClient,
-            transports.SpecialistPoolServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-        ),
-    ],
-)
-@mock.patch.object(
-    SpecialistPoolServiceClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(SpecialistPoolServiceClient),
-)
-@mock.patch.object(
-    SpecialistPoolServiceAsyncClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(SpecialistPoolServiceAsyncClient),
-)
-def test_specialist_pool_service_client_client_options(
-    client_class, transport_class, transport_name
-):
+@pytest.mark.parametrize("client_class,transport_class,transport_name", [
+    (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"),
+    (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio")
+])
+@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient))
+@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient))
+def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name):
     # Check that if channel is provided we won't create a new one.
-    with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc:
-        transport = transport_class(credentials=credentials.AnonymousCredentials())
+    with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc:
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials()
+        )
         client = client_class(transport=transport)
         gtc.assert_not_called()

     # Check that if channel is provided via str we will create a new one.
-    with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc:
+    with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc:
         client = client_class(transport=transport_name)
         gtc.assert_called()

     # Check the case api_endpoint is provided.
     options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch.object(transport_class, "__init__") as patched:
+    with mock.patch.object(transport_class, '__init__') as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -180,7 +134,7 @@ def test_specialist_pool_service_client_client_options(
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "never".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
+        with mock.patch.object(transport_class, '__init__') as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -196,7 +150,7 @@ def test_specialist_pool_service_client_client_options(
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "always".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
+        with mock.patch.object(transport_class, '__init__') as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -216,15 +170,13 @@ def test_specialist_pool_service_client_client_options(
             client = client_class()

     # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
-    ):
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
         with pytest.raises(ValueError):
             client = client_class()

     # Check the case quota_project_id is provided
     options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, "__init__") as patched:
+    with mock.patch.object(transport_class, '__init__') as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -237,66 +189,26 @@ def test_specialist_pool_service_client_client_options(
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-
-@pytest.mark.parametrize(
-    "client_class,transport_class,transport_name,use_client_cert_env",
-    [
-        (
-            SpecialistPoolServiceClient,
-            transports.SpecialistPoolServiceGrpcTransport,
-            "grpc",
-            "true",
-        ),
-        (
-            SpecialistPoolServiceAsyncClient,
-            transports.SpecialistPoolServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-            "true",
-        ),
-        (
-            SpecialistPoolServiceClient,
-            transports.SpecialistPoolServiceGrpcTransport,
-            "grpc",
-            "false",
-        ),
-        (
-            SpecialistPoolServiceAsyncClient,
-            transports.SpecialistPoolServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-            "false",
-        ),
-    ],
-)
-@mock.patch.object(
-    SpecialistPoolServiceClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(SpecialistPoolServiceClient),
-)
-@mock.patch.object(
-    SpecialistPoolServiceAsyncClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(SpecialistPoolServiceAsyncClient),
-)
+@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
+    (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"),
+    (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
+    (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"),
+    (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false")
+])
+@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient))
+@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient))
 @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
-def test_specialist_pool_service_client_mtls_env_auto(
-    client_class, transport_class, transport_name, use_client_cert_env
-):
+def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
     # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
     # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.

     # Check the case client_cert_source is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): + with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): patched.return_value = None client = client_class(client_options=options) @@ -319,21 +231,11 @@ def test_specialist_pool_service_client_mtls_env_auto( # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -343,9 +245,7 @@ def test_specialist_pool_service_client_mtls_env_auto( is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + expected_ssl_channel_creds = ssl_credentials_mock.return_value patched.return_value = None client = client_class() @@ -360,17 +260,10 @@ def test_specialist_pool_service_client_mtls_env_auto( ) # Check the case client_cert_source and ADC client cert are not provided. 
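# Aside: taken together, these branches follow the GOOGLE_API_USE_MTLS_ENDPOINT
# contract asserted throughout this test: "always" forces the mTLS endpoint,
# "never" forces the default endpoint, and "auto" picks the mTLS endpoint only
# when a client certificate is available. A standalone sketch of that selection
# logic (the helper name and endpoint strings are illustrative, not library API):

def select_endpoint(use_mtls_env, cert_available,
                    default="aiplatform.googleapis.com",
                    mtls="aiplatform.mtls.googleapis.com"):
    if use_mtls_env == "always":
        return mtls
    if use_mtls_env == "auto" and cert_available:
        return mtls
    return default  # "never", or "auto" without a client certificate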
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): + with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -385,27 +278,16 @@ def test_specialist_pool_service_client_mtls_env_auto( ) -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_scopes( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -418,28 +300,16 @@ def test_specialist_pool_service_client_client_options_scopes( client_info=transports.base.DEFAULT_CLIENT_INFO, ) - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), + (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") +]) +def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): # Check the case credentials file is provided. 
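# Aside: a minimal sketch of the option this test exercises, assuming the
# google-api-core ClientOptions API; the key-file path is a placeholder, and
# the credentials are loaded from that file instead of falling back to
# Application Default Credentials.

from google.api_core import client_options

example_options = client_options.ClientOptions(
    credentials_file="service-account.json"  # placeholder path
)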
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + with mock.patch.object(transport_class, '__init__') as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -454,12 +324,10 @@ def test_specialist_pool_service_client_client_options_credentials_file( def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" - ) as grpc_transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} + client_options={'api_endpoint': 'squid.clam.whelk'} ) grpc_transport.assert_called_once_with( credentials=None, @@ -472,12 +340,10 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): +def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -486,10 +352,10 @@ def test_create_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.create_specialist_pool(request) @@ -508,22 +374,23 @@ def test_create_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_create_specialist_pool_async(transport: str = "grpc_asyncio"): +async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.CreateSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.CreateSpecialistPoolRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. 
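# Aside: FakeUnaryUnaryCall (google.api_core.grpc_helpers_async) wraps a plain
# response object so the mocked stub can be awaited like a real unary-unary
# RPC. A standalone sketch of the same idea, with an illustrative class name:

class ExampleFakeCall:
    def __init__(self, response):
        self._response = response

    def __await__(self):
        yield  # relinquish control once, as a real RPC would
        return self._response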
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.create_specialist_pool(request) @@ -532,12 +399,17 @@ async def test_create_specialist_pool_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_create_specialist_pool_async_from_dict(): + await test_create_specialist_pool_async(request_type=dict) + + def test_create_specialist_pool_field_headers(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), @@ -546,13 +418,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.create_specialist_pool(request) @@ -563,7 +435,10 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -575,15 +450,13 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.create_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.create_specialist_pool(request) @@ -594,7 +467,10 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_create_specialist_pool_flattened(): @@ -604,16 +480,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -621,11 +497,9 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') def test_create_specialist_pool_flattened_error(): @@ -638,8 +512,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) @@ -651,19 +525,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), "__call__" - ) as call: + type(client.transport.create_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) # Establish that the underlying call was made with the expected @@ -671,11 +545,9 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') @pytest.mark.asyncio @@ -689,17 +561,15 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + parent='parent_value', + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), ) -def test_get_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): +def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -708,15 +578,20 @@ def test_get_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", + name='name_value', + + display_name='display_name_value', + specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + + specialist_manager_emails=['specialist_manager_emails_value'], + + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + ) response = client.get_specialist_pool(request) @@ -728,17 +603,18 @@ def test_get_specialist_pool( assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] def test_get_specialist_pool_from_dict(): @@ -746,29 +622,28 @@ def test_get_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_get_specialist_pool_async(transport: str = "grpc_asyncio"): +async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.GetSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.GetSpecialistPoolRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( + name='name_value', + display_name='display_name_value', + specialist_managers_count=2662, + specialist_manager_emails=['specialist_manager_emails_value'], + pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], + )) response = await client.get_specialist_pool(request) @@ -776,20 +651,25 @@ async def test_get_specialist_pool_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" + assert response.name == 'name_value' - assert response.display_name == "display_name_value" + assert response.display_name == 'display_name_value' assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.specialist_manager_emails == ['specialist_manager_emails_value'] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + + +@pytest.mark.asyncio +async def test_get_specialist_pool_async_from_dict(): + await test_get_specialist_pool_async(request_type=dict) def test_get_specialist_pool_field_headers(): @@ -800,12 +680,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -817,7 +697,10 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -829,15 +712,13 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) + type(client.transport.get_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) await client.get_specialist_pool(request) @@ -848,7 +729,10 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_get_specialist_pool_flattened(): @@ -858,21 +742,23 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.get_specialist_pool(name="name_value",) + client.get_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_get_specialist_pool_flattened_error(): @@ -884,7 +770,8 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + specialist_pool_service.GetSpecialistPoolRequest(), + name='name_value', ) @@ -896,24 +783,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), "__call__" - ) as call: + type(client.transport.get_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool(name="name_value",) + response = await client.get_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio @@ -926,16 +813,15 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + specialist_pool_service.GetSpecialistPoolRequest(), + name='name_value', ) -def test_list_specialist_pools( - transport: str = "grpc", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): +def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -944,11 +830,12 @@ def test_list_specialist_pools( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", + next_page_token='next_page_token_value', + ) response = client.list_specialist_pools(request) @@ -960,9 +847,10 @@ def test_list_specialist_pools( assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' def test_list_specialist_pools_from_dict(): @@ -970,25 +858,24 @@ def test_list_specialist_pools_from_dict(): @pytest.mark.asyncio -async def test_list_specialist_pools_async(transport: str = "grpc_asyncio"): +async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.ListSpecialistPoolsRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.ListSpecialistPoolsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token='next_page_token_value', + )) response = await client.list_specialist_pools(request) @@ -996,12 +883,17 @@ async def test_list_specialist_pools_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == "next_page_token_value" + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async_from_dict(): + await test_list_specialist_pools_async(request_type=dict) def test_list_specialist_pools_field_headers(): @@ -1012,12 +904,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -1029,7 +921,10 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1041,15 +936,13 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
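# Aside: the "x-goog-request-params" entry asserted in these field-header
# tests is routing metadata derived from the request; a standalone sketch
# matching the asserted format (the helper is illustrative and, unlike the
# real gapic_v1.routing_header, does not URL-encode values):

def example_routing_header(**fields):
    value = "&".join(f"{k}={v}" for k, v in fields.items())
    return ("x-goog-request-params", value)


assert example_routing_header(parent="parent/value") == (
    "x-goog-request-params",
    "parent=parent/value",
)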
request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" + request.parent = 'parent/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) + type(client.transport.list_specialist_pools), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) await client.list_specialist_pools(request) @@ -1060,7 +953,10 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'parent=parent/value', + ) in kw['metadata'] def test_list_specialist_pools_flattened(): @@ -1070,21 +966,23 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_specialist_pools(parent="parent_value",) + client.list_specialist_pools( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' def test_list_specialist_pools_flattened_error(): @@ -1096,7 +994,8 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + specialist_pool_service.ListSpecialistPoolsRequest(), + parent='parent_value', ) @@ -1108,24 +1007,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools(parent="parent_value",) + response = await client.list_specialist_pools( + parent='parent_value', + ) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" + assert args[0].parent == 'parent_value' @pytest.mark.asyncio @@ -1138,17 +1037,20 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + specialist_pool_service.ListSpecialistPoolsRequest(), + parent='parent_value', ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1157,14 +1059,17 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1177,7 +1082,9 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + gapic_v1.routing_header.to_grpc_metadata(( + ('parent', ''), + )), ) pager = client.list_specialist_pools(request={}) @@ -1185,16 +1092,18 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) - + assert all(isinstance(i, specialist_pool.SpecialistPool) + for i in results) def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials, + ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), "__call__" - ) as call: + type(client.transport.list_specialist_pools), + '__call__') as call: # Set the response to a series of pages. 
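# Aside: assigning an iterable to a mock's side_effect makes each successive
# call return the next item, which is how the paging tests below feed the
# client one ListSpecialistPoolsResponse per page (with a trailing
# RuntimeError in case the pager over-fetches). A standalone sketch:

from unittest import mock

fake_rpc = mock.Mock(side_effect=["page-1", "page-2", "page-3", RuntimeError])
assert fake_rpc() == "page-1"
assert fake_rpc() == "page-2"
assert fake_rpc() == "page-3"
# A fourth call would raise RuntimeError.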
call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1203,14 +1112,17 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1221,10 +1133,9 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token - @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1233,10 +1144,8 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_specialist_pools), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1245,14 +1154,17 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1263,14 +1175,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == "abc" + assert async_pager.next_page_token == 'abc' responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) - + assert all(isinstance(i, specialist_pool.SpecialistPool) + for i in responses) @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1280,10 +1192,8 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: + type(client.transport.list_specialist_pools), + '__call__', new_callable=mock.AsyncMock) as call: # Set the response to a series of pages. 
call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1292,14 +1202,17 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token="abc", + next_page_token='abc', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", + specialist_pools=[], + next_page_token='def', ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", + specialist_pools=[ + specialist_pool.SpecialistPool(), + ], + next_page_token='ghi', ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1312,16 +1225,14 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + for page_, token in zip(pages, ['abc','def','ghi', '']): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): +def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1330,10 +1241,10 @@ def test_delete_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.delete_specialist_pool(request) @@ -1352,22 +1263,23 @@ def test_delete_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_delete_specialist_pool_async(transport: str = "grpc_asyncio"): +async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.DeleteSpecialistPoolRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.delete_specialist_pool(request) @@ -1376,12 +1288,17 @@ async def test_delete_specialist_pool_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_specialist_pool_async_from_dict(): + await test_delete_specialist_pool_async(request_type=dict) + + def test_delete_specialist_pool_field_headers(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1390,13 +1307,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.delete_specialist_pool(request) @@ -1407,7 +1324,10 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1419,15 +1339,13 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" + request.name = 'name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.delete_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.delete_specialist_pool(request) @@ -1438,7 +1356,10 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + assert ( + 'x-goog-request-params', + 'name=name/value', + ) in kw['metadata'] def test_delete_specialist_pool_flattened(): @@ -1448,21 +1369,23 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool(name="name_value",) + client.delete_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' def test_delete_specialist_pool_flattened_error(): @@ -1474,7 +1397,8 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + specialist_pool_service.DeleteSpecialistPoolRequest(), + name='name_value', ) @@ -1486,24 +1410,26 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), "__call__" - ) as call: + type(client.transport.delete_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool(name="name_value",) + response = await client.delete_specialist_pool( + name='name_value', + ) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" + assert args[0].name == 'name_value' @pytest.mark.asyncio @@ -1516,16 +1442,15 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + specialist_pool_service.DeleteSpecialistPoolRequest(), + name='name_value', ) -def test_update_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): +def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1534,10 +1459,10 @@ def test_update_specialist_pool( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. 
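# Aside: create/delete/update are long-running operations: the client wraps
# the raw operations_pb2.Operation in a google.api_core operation future,
# which is why these tests assert isinstance(response, future.Future). A
# hedged usage sketch under that assumption (names are illustrative):

def example_update_and_wait(client, request):
    lro = client.update_specialist_pool(request=request)
    # result() polls the operation and blocks until it completes or times out.
    return lro.result(timeout=300)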
- call.return_value = operations_pb2.Operation(name="operations/spam") + call.return_value = operations_pb2.Operation(name='operations/spam') response = client.update_specialist_pool(request) @@ -1556,22 +1481,23 @@ def test_update_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_update_specialist_pool_async(transport: str = "grpc_asyncio"): +async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.UpdateSpecialistPoolRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) response = await client.update_specialist_pool(request) @@ -1580,12 +1506,17 @@ async def test_update_specialist_pool_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_update_specialist_pool_async_from_dict(): + await test_update_specialist_pool_async(request_type=dict) + + def test_update_specialist_pool_field_headers(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1594,13 +1525,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" + request.specialist_pool.name = 'specialist_pool.name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = operations_pb2.Operation(name='operations/op') client.update_specialist_pool(request) @@ -1612,9 +1543,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] @pytest.mark.asyncio @@ -1626,15 +1557,13 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" + request.specialist_pool.name = 'specialist_pool.name/value' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) + type(client.transport.update_specialist_pool), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) await client.update_specialist_pool(request) @@ -1646,9 +1575,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] + 'x-goog-request-params', + 'specialist_pool.name=specialist_pool.name/value', + ) in kw['metadata'] def test_update_specialist_pool_flattened(): @@ -1658,16 +1587,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1675,11 +1604,9 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) def test_update_specialist_pool_flattened_error(): @@ -1692,8 +1619,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @@ -1705,19 +1632,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), "__call__" - ) as call: + type(client.transport.update_specialist_pool), + '__call__') as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") + call.return_value = operations_pb2.Operation(name='operations/op') call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") + operations_pb2.Operation(name='operations/spam') ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) # Establish that the underlying call was made with the expected @@ -1725,11 +1652,9 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) @pytest.mark.asyncio @@ -1743,8 +1668,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), + specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + update_mask=field_mask.FieldMask(paths=['paths_value']), ) @@ -1755,7 +1680,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, + credentials=credentials.AnonymousCredentials(), + transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1774,7 +1700,8 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, + client_options={"scopes": ["1", "2"]}, + transport=transport, ) @@ -1802,16 +1729,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport +]) def test_transport_adc(transport_class): # Test default credentials are used if not provided. 
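# Aside: google.auth.default() is the Application Default Credentials entry
# point being patched here; unmocked, it resolves credentials from the
# environment and returns a (credentials, project_id) tuple. A minimal sketch
# of the real call:

import google.auth

example_creds, example_project = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)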
- with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1822,7 +1746,10 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) + assert isinstance( + client.transport, + transports.SpecialistPoolServiceGrpcTransport, + ) def test_specialist_pool_service_base_transport_error(): @@ -1830,15 +1757,13 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", + credentials_file="credentials.json" ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" - ) as Transport: + with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1847,12 +1772,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - "create_specialist_pool", - "get_specialist_pool", - "list_specialist_pools", - "delete_specialist_pool", - "update_specialist_pool", - ) + 'create_specialist_pool', + 'get_specialist_pool', + 'list_specialist_pools', + 'delete_specialist_pool', + 'update_specialist_pool', + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1865,28 +1790,23 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", + credentials_file="credentials.json", + quota_project_id="octopus", ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), + load_creds.assert_called_once_with("credentials.json", scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
- with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" - ) as Transport: + with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -1895,11 +1815,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id=None, ) @@ -1907,75 +1827,62 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), + transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") + adc.assert_called_once_with(scopes=( + 'https://www.googleapis.com/auth/cloud-platform',), quota_project_id="octopus", ) - def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), ) - assert client.transport._host == "aiplatform.googleapis.com:443" + assert client.transport._host == 'aiplatform.googleapis.com:443' def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), + client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), ) - assert client.transport._host == "aiplatform.googleapis.com:8000" + assert client.transport._host == 'aiplatform.googleapis.com:8000' def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.insecure_channel('http://localhost/') # Check that channel is used if provided. 
transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.insecure_channel('http://localhost/') # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, + host="squid.clam.whelk", + channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class, + transport_class ): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1984,7 +1891,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: + with mock.patch.object(auth, 'default') as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2000,30 +1907,27 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize( - "transport_class", - [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - ], -) -def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): +@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +def test_specialist_pool_service_transport_channel_mtls_with_adc( + transport_class +): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: + with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = 
mock.Mock() @@ -2040,7 +1944,9 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=( + 'https://www.googleapis.com/auth/cloud-platform', + ), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -2049,12 +1955,16 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", + credentials=credentials.AnonymousCredentials(), + transport='grpc', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2062,36 +1972,36 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + credentials=credentials.AnonymousCredentials(), + transport='grpc_asyncio', ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client - def test_specialist_pool_path(): project = "squid" location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( - project=project, location=location, specialist_pool=specialist_pool, - ) - actual = SpecialistPoolServiceClient.specialist_pool_path( - project, location, specialist_pool - ) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) + actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", + } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -2099,20 +2009,18 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual - def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format( - billing_account=billing_account, - ) + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", + "billing_account": "mussel", + } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -2120,18 +2028,18 @@ def 
test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual - def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder,) + expected = "folders/{folder}".format(folder=folder, ) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", + "folder": "nautilus", + } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2139,18 +2047,18 @@ def test_parse_common_folder_path(): actual = SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual - def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization,) + expected = "organizations/{organization}".format(organization=organization, ) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", + "organization": "abalone", + } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2158,18 +2066,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual - def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project,) + expected = "projects/{project}".format(project=project, ) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", + "project": "clam", + } path = SpecialistPoolServiceClient.common_project_path(**expected) @@ -2177,22 +2085,20 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual - def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format( - project=project, location=location, - ) + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", + "project": "oyster", + "location": "nudibranch", + } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2204,19 +2110,17 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object( - transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object( - transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" - ) as prep: + with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), 
client_info=client_info, + credentials=credentials.AnonymousCredentials(), + client_info=client_info, ) prep.assert_called_once_with(client_info) From 57aa05ffa98e7c216000775e68e33c43ac001c93 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 12:20:32 -0800 Subject: [PATCH 04/12] update beta --- docs/conf.py | 7 +- .../services/job_service/async_client.py | 1 + .../services/job_service/client.py | 1 + .../types/batch_prediction_job.py | 24 + .../types/explanation_metadata.py | 10 +- .../cloud/aiplatform_v1beta1/types/model.py | 117 +- noxfile.py | 30 +- setup.py | 40 +- synth.metadata | 4 +- .../test_dataset_service.py | 2168 +++++----- .../test_endpoint_service.py | 1537 ++++---- .../aiplatform_v1beta1/test_job_service.py | 3501 ++++++++--------- .../test_migration_service.py | 907 +++-- .../aiplatform_v1beta1/test_model_service.py | 2269 +++++------ .../test_pipeline_service.py | 1230 +++--- .../test_specialist_pool_service.py | 1097 +++--- 16 files changed, 6245 insertions(+), 6698 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index b45ecd8682..effa4a8f1f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -347,12 +347,9 @@ intersphinx_mapping = { "python": ("http://python.readthedocs.org/en/latest/", None), "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), - "google.api_core": ( - "https://googleapis.dev/python/google-api-core/latest/", - None, - ), + "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,), "grpc": ("https://grpc.io/grpc/python/", None), - + "proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None), } diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index da6fafd965..d988c81d3c 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -42,6 +42,7 @@ from google.cloud.aiplatform_v1beta1.types import ( data_labeling_job as gca_data_labeling_job, ) +from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import ( hyperparameter_tuning_job as gca_hyperparameter_tuning_job, diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index cf840174c5..a1eb7c38ce 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -46,6 +46,7 @@ from google.cloud.aiplatform_v1beta1.types import ( data_labeling_job as gca_data_labeling_job, ) +from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import ( hyperparameter_tuning_job as gca_hyperparameter_tuning_job, diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 64892b8271..625bf83155 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform_v1beta1.types import ( completion_stats as gca_completion_stats, ) +from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types 
import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources @@ -114,6 +115,25 @@ class BatchPredictionJob(proto.Message): object. - ``csv``: Generating explanations for CSV format is not supported. + explanation_spec (~.explanation.ExplanationSpec): + Explanation configuration for this BatchPredictionJob. Can + only be specified if + ``generate_explanation`` + is set to ``true``. It's invalid to specify it with + generate_explanation set to false or unset. + + This value overrides the value of + ``Model.explanation_spec``. + All fields of + ``explanation_spec`` + are optional in the request. If a field of + ``explanation_spec`` + is not populated, the value of the same field of + ``Model.explanation_spec`` + is inherited. The corresponding + ``Model.explanation_spec`` + must be populated, otherwise explanation for this Model is + not allowed. output_info (~.batch_prediction_job.BatchPredictionJob.OutputInfo): Output only. Information further describing the output of this job. @@ -326,6 +346,10 @@ class OutputInfo(proto.Message): generate_explanation = proto.Field(proto.BOOL, number=23) + explanation_spec = proto.Field( + proto.MESSAGE, number=25, message=explanation.ExplanationSpec, + ) + output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 38520669ef..7261c064f8 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -47,7 +47,7 @@ class ExplanationMetadata(proto.Message): be keyed by this key (if not grouped with another feature). For custom images, the key must match with the key in - ``instance``[]. + ``instance``. outputs (Sequence[~.explanation_metadata.ExplanationMetadata.OutputsEntry]): Required. Map from output names to output metadata. @@ -183,10 +183,10 @@ class FeatureValueDomain(proto.Message): (with mean = 0 and stddev = 1) was obtained. Attributes: - min_ (float): + min_value (float): The minimum permissible value for this feature. - max_ (float): + max_value (float): The maximum permissible value for this feature. original_mean (float): Original mean of the feature in the dataset. For a feature that has been rescaled by dividing by a scaling constant. It is the original mean divided by deviation of the domain prior to normalization. """ - min_ = proto.Field(proto.FLOAT, number=1) + min_value = proto.Field(proto.FLOAT, number=1) - max_ = proto.Field(proto.FLOAT, number=2) + max_value = proto.Field(proto.FLOAT, number=2) original_mean = proto.Field(proto.FLOAT, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 7fa9130909..21e8c41034 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -70,7 +70,7 @@ class Model(proto.Message): supported_export_formats (Sequence[~.model.Model.ExportFormat]): Output only. The formats in which this Model may be exported. If empty, this Model is not - avaiable for export. + available for export. training_pipeline (str): Output only. The resource name of the TrainingPipeline that uploaded this Model, if @@ -390,20 +390,18 @@ class PredictSchemata(proto.Message): class ModelContainerSpec(proto.Message): r"""Specification of a container for serving predictions.
This message - is a subset of the [Kubernetes Container v1 core - - specification](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). + is a subset of the Kubernetes Container v1 core + `specification <https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core>`__. Attributes: image_uri (str): Required. Immutable. URI of the Docker image to be used as the custom container for serving predictions. This URI must identify an image in Artifact Registry or Container - Registry. Learn more about the [container publishing - - requirements](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#publishing), including permissions requirements for the AI Platform - Service Agent. + Registry. Learn more about the container publishing + requirements, including permissions requirements for the AI + Platform Service Agent, + `here <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#publishing>`__. The container image is ingested upon ``ModelService.UploadModel``, used. To learn about the requirements for the Docker image itself, - read [Custom container - - requirements](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements). + see `Custom container + requirements <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements>`__. command (Sequence[str]): Immutable. Specifies the command that runs when the container starts. This overrides the container's - - [``ENTRYPOINT``](https://docs.docker.com/engine/reference/builder/#entrypoint). + `ENTRYPOINT <https://docs.docker.com/engine/reference/builder/#entrypoint>`__. Specify this field as an array of executable and arguments, similar to a Docker ``ENTRYPOINT``'s "exec" form, not its "shell" form. ```CMD`` <https://docs.docker.com/engine/reference/builder/#cmd>`__, if either exists. If this field is not specified and the container does not have an ``ENTRYPOINT``, then refer to the - [Docker documentation about how ``CMD`` and ``ENTRYPOINT`` - - interact](https://docs.docker.com/engine/reference/builder/#understand-how-cmd-and-entrypoint-interact). + Docker documentation about how ``CMD`` and ``ENTRYPOINT`` + `interact <https://docs.docker.com/engine/reference/builder/#understand-how-cmd-and-entrypoint-interact>`__. If you specify this field, then you can also specify the ``args`` field to provide additional arguments for this command. However, if you specify this field, then the - container's ``CMD`` is ignored. See the [Kubernetes - documentation about how the ``command`` and ``args`` fields - interact with a container's ``ENTRYPOINT`` and + container's ``CMD`` is ignored. See the `Kubernetes + documentation <https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#notes>`__ about how + the ``command`` and ``args`` fields interact with a + container's ``ENTRYPOINT`` and ``CMD``. - ``CMD``](https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#notes). - In this field, you can reference [environment variables set + In this field, you can reference environment variables `set by AI - - Platform](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables) + Platform <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables>`__ and environment variables set in the ``env`` field. You cannot reference environment variables set in the cannot be resolved, the reference in the input string is used unchanged. To avoid variable expansion, you can escape this syntax with ``$$``; for example: $$(VARIABLE_NAME) This - field corresponds to the ``command`` field of the - [Kubernetes Containers v1 core - - API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). + field corresponds to the ``command`` field of the Kubernetes + Containers `v1 core + API <https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core>`__.
args (Sequence[str]): Immutable. Specifies arguments for the command that runs when the container starts. This overrides the container's ```CMD`` <https://docs.docker.com/engine/reference/builder/#cmd>`__. Specify this field as an array of executable and arguments, similar to a Docker ``CMD``'s "default parameters" form. If you don't specify this field but do specify the ``command`` field, then the command from the ``command`` field runs - without any additional arguments. See the [Kubernetes - documentation about how the ``command`` and ``args`` fields - interact with a container's ``ENTRYPOINT`` and - - ``CMD``](https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#notes). + without any additional arguments. See the `Kubernetes + documentation <https://kubernetes.io/docs/tasks/inject-data-application/define-command-argument-container/#notes>`__ about how + the ``command`` and ``args`` fields interact with a + container's ``ENTRYPOINT`` and ``CMD``. If you don't specify this field and don't specify the ``command`` field, then the container's ```ENTRYPOINT`` <https://docs.docker.com/engine/reference/builder/#entrypoint>`__ and ``CMD`` determine what runs based on their default - behavior. See the [Docker documentation about how ``CMD`` - and ``ENTRYPOINT`` - - interact](https://docs.docker.com/engine/reference/builder/#understand-how-cmd-and-entrypoint-interact). + behavior. See the Docker documentation about how ``CMD`` and + ``ENTRYPOINT`` `interact <https://docs.docker.com/engine/reference/builder/#understand-how-cmd-and-entrypoint-interact>`__. - In this field, you can reference [environment variables set + In this field, you can reference environment variables `set by AI - - Platform](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables) + Platform <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables>`__ and environment variables set in the ``env`` field. You cannot reference environment variables set in the cannot be resolved, the reference in the input string is used unchanged. To avoid variable expansion, you can escape this syntax with ``$$``; for example: $$(VARIABLE_NAME) This - field corresponds to the ``args`` field of the [Kubernetes - Containers v1 core - - API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). + field corresponds to the ``args`` field of the Kubernetes + Containers `v1 core + API <https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core>`__. env (Sequence[~.env_var.EnvVar]): Immutable. List of environment variables to set in the container. After the container starts running, code running in the container can read these environment variables. Additionally, the ``command`` and ``args`` fields can reference these variables. Later entries in this list can also reference earlier entries. For example, the following example sets the variable ``VAR_2`` to have the value ``foo bar``: If you switch the order of the variables in the example, then the expansion does not occur. This field corresponds to the ``env`` field of the - [Kubernetes Containers v1 core - - API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). + Kubernetes Containers `v1 core + API <https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core>`__. ports (Sequence[~.model.Port]): Immutable. List of ports to expose from the container. AI Platform sends any prediction requests that it receives to the first port on this list. AI Platform also sends - [liveness and health - - checks](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#health) to + `liveness and health + checks <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#health>`__ to this port. If you do not specify this field, it defaults to following AI Platform does not use ports other than the first one listed. This field corresponds to the ``ports`` field of the - [Kubernetes Containers v1 core - - API](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core). + Kubernetes Containers `v1 core + API <https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.18/#container-v1-core>`__. predict_route (str): Immutable. HTTP path on the container to send prediction requests to.
AI Platform forwards requests sent using the Endpoint.name][] field of the Endpoint where this Model has been deployed. (AI Platform makes this value available to your container code as the - [``AIP_ENDPOINT_ID`` environment - - variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + ```AIP_ENDPOINT_ID`` <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables>`__ + environment variable.) - DEPLOYED_MODEL: ``DeployedModel.id`` of the ``DeployedModel``. (AI Platform makes this value available to your container code as the - [``AIP_DEPLOYED_MODEL_ID`` environment - - variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + ```AIP_DEPLOYED_MODEL_ID`` environment + variable <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables>`__.) health_route (str): Immutable. HTTP path on the container to send health checks to. AI Platform intermittently sends GET requests to this path on the container's IP address and port to check that - the container is healthy. Read more about [health - - checks](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#checks). + the container is healthy. Read more about `health + checks <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#checks>`__. For example, if you set this field to ``/bar``, then AI Platform intermittently sends a GET request to the following the Endpoint.name][] field of the Endpoint where this Model has been deployed. (AI Platform makes this value available to your container code as the - [``AIP_ENDPOINT_ID`` environment - - variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + ```AIP_ENDPOINT_ID`` <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables>`__ + environment variable.) - DEPLOYED_MODEL: ``DeployedModel.id`` of the ``DeployedModel``. (AI Platform makes this value available to your container code as the - [``AIP_DEPLOYED_MODEL_ID`` environment - - variable](https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables).) + ```AIP_DEPLOYED_MODEL_ID`` <https://cloud.google.com/ai-platform-unified/docs/predictions/custom-container-requirements#aip-variables>`__ + environment variable.) """ image_uri = proto.Field(proto.STRING, number=1) diff --git a/noxfile.py b/noxfile.py index 4f20df5c36..4b2538e5c9 100644 --- a/noxfile.py +++ b/noxfile.py @@ -40,9 +40,7 @@ def lint(session): """ session.install("flake8", BLACK_VERSION) session.run( - "black", - "--check", - *BLACK_PATHS, + "black", "--check", *BLACK_PATHS, ) session.run("flake8", "google", "tests") @@ -59,8 +57,7 @@ def blacken(session): """ session.install(BLACK_VERSION) session.run( - "black", - *BLACK_PATHS, + "black", *BLACK_PATHS, ) @@ -74,8 +71,10 @@ def default(session): # Install all test dependencies, then install this package in-place. session.install("asyncmock", "pytest-asyncio") - - session.install("mock", "pytest", "pytest-cov") + + session.install( + "mock", "pytest", "pytest-cov", + ) session.install("-e", ".") # Run py.test against the unit tests. @@ -93,6 +92,7 @@ def default(session): *session.posargs, ) + @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) def unit(session): """Run the unit test suite.""" @@ -106,7 +106,7 @@ def system(session): system_test_folder_path = os.path.join("tests", "system") # Check the value of `RUN_SYSTEM_TESTS` env var. It defaults to true.
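The reworked ``ModelContainerSpec`` docstrings above spell out the serving-container contract: ``command`` and ``args`` override the image's ``ENTRYPOINT`` and ``CMD``, the first entry in ``ports`` receives prediction traffic, and ``predict_route``/``health_route`` name the HTTP paths AI Platform calls. A hedged sketch of a spec using those fields (the image URI, routes, and the ``$(AIP_STORAGE_URI)`` expansion are illustrative values, not taken from this patch):

```python
from google.cloud.aiplatform_v1beta1.types import model

# Illustrative serving container: `command` replaces the image's ENTRYPOINT,
# `args` replaces its CMD, and $(VAR) references are expanded by AI Platform.
container_spec = model.ModelContainerSpec(
    image_uri="gcr.io/example-project/my-server:latest",  # hypothetical image
    command=["python3", "server.py"],
    args=["--model-dir", "$(AIP_STORAGE_URI)"],  # assumed AIP variable
    ports=[model.Port(container_port=8080)],  # first port gets predictions
    predict_route="/predict",  # AI Platform forwards prediction requests here
    health_route="/health",    # intermittent GET health checks land here
)
```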
- if os.environ.get("RUN_SYSTEM_TESTS", "true") == 'false': + if os.environ.get("RUN_SYSTEM_TESTS", "true") == "false": session.skip("RUN_SYSTEM_TESTS is set to false, skipping") # Sanity check: Only run tests if the environment variable is set. if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): @@ -123,10 +123,11 @@ def system(session): # Install all test dependencies, then install this package into the # virtualenv's dist-packages. - session.install("mock", "pytest", "google-cloud-testutils", ) + session.install( + "mock", "pytest", "google-cloud-testutils", + ) session.install("-e", ".") - # Run py.test against the system tests. if system_test_exists: session.run("py.test", "--quiet", system_test_path, *session.posargs) @@ -134,7 +135,6 @@ def system(session): session.run("py.test", "--quiet", system_test_folder_path, *session.posargs) - @nox.session(python=DEFAULT_PYTHON_VERSION) def cover(session): """Run the final coverage report. @@ -147,16 +147,18 @@ def cover(session): session.run("coverage", "erase") + @nox.session(python=DEFAULT_PYTHON_VERSION) def docs(session): """Build the docs for this library.""" - session.install('-e', '.') - session.install('sphinx', 'alabaster', 'recommonmark') + session.install("-e", ".") + session.install("sphinx", "alabaster", "recommonmark") - shutil.rmtree(os.path.join('docs', '_build'), ignore_errors=True) + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( "sphinx-build", + "-W", # warnings as errors "-T", # show full traceback on exception "-N", # no colors "-b", diff --git a/setup.py b/setup.py index 8f159e0dc7..82468cded3 100644 --- a/setup.py +++ b/setup.py @@ -19,32 +19,30 @@ setuptools.setup( - name='google-cloud-aiplatform', - version='0.3.0', + name="google-cloud-aiplatform", + version="0.3.0", packages=setuptools.PEP420PackageFinder.find(), - namespace_packages=('google', 'google.cloud'), - platforms='Posix; MacOS X; Windows', + namespace_packages=("google", "google.cloud"), + platforms="Posix; MacOS X; Windows", include_package_data=True, install_requires=( - 'google-api-core[grpc] >= 1.22.2, < 2.0.0dev', - 'libcst >= 0.2.5', - 'proto-plus >= 1.4.0', - 'mock >= 4.0.2', - 'google-cloud-storage >= 1.26.0', + "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", + "libcst >= 0.2.5", + "proto-plus >= 1.4.0", + "mock >= 4.0.2", + "google-cloud-storage >= 1.26.0", ), - python_requires='>=3.6', - scripts=[ - 'scripts/fixup_aiplatform_v1beta1_keywords.py', - ], + python_requires=">=3.6", + scripts=["scripts/fixup_aiplatform_v1beta1_keywords.py",], classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Topic :: Internet', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Topic :: Internet", + "Topic :: Software Development :: Libraries :: Python Modules", ], zip_safe=False, ) diff --git a/synth.metadata b/synth.metadata index 866be8e22e..bc34821e1f 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,14 +4,14 @@ "git": { "name": ".", "remote": "https://github.com/dizcology/python-aiplatform.git", - "sha": "b428c3bd3c19861cb431595f71aa43123e0dd1af" + 
"sha": "60263c04ffd04dabd7cc95c138b9f1c87566208c" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "f68649c5f26bcff6817c6d21e90dac0fc71fef8e" + "sha": "ba9918cd22874245b55734f57470c719b577e591" } } ], diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 08020beb3c..51022d9fb7 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.dataset_service import DatasetServiceClient +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceClient, +) from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.services.dataset_service import transports from google.cloud.aiplatform_v1beta1.types import annotation @@ -62,7 +66,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -73,17 +81,35 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [DatasetServiceClient, DatasetServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] +) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = 
client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -91,7 +117,7 @@ def test_dataset_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_client_get_transport_class(): @@ -102,29 +128,44 @@ def test_dataset_service_client_get_transport_class(): assert transport == transports.DatasetServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) -def test_dataset_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -140,7 +181,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". 
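These options tests document the endpoint-selection rules: an explicit ``api_endpoint`` in ``ClientOptions`` always wins, while ``GOOGLE_API_USE_MTLS_ENDPOINT`` (``never``/``always``/``auto``) only steers the default choice. A short sketch of the explicit override, under the assumption that a regional endpoint is desired (the hostname is a placeholder):

```python
from google.api_core import client_options
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
)

# An explicit api_endpoint pins the host regardless of the
# GOOGLE_API_USE_MTLS_ENDPOINT setting exercised by these tests.
opts = client_options.ClientOptions(
    api_endpoint="us-central1-aiplatform.googleapis.com",  # placeholder host
)
client = DatasetServiceClient(client_options=opts)
```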
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -156,7 +197,7 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -176,13 +217,15 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -195,26 +238,56 @@ def test_dataset_service_client_client_options(client_class, transport_class, tr client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(DatasetServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceClient)) -@mock.patch.object(DatasetServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DatasetServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_dataset_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -237,11 +310,21 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -251,7 +334,9 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -266,10 +351,17 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) # Check the case client_cert_source and ADC client cert are not provided. 
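The autoswitch cases above encode the client-side mTLS rules: a client certificate is only used when ``GOOGLE_API_USE_CLIENT_CERTIFICATE`` is ``true``, and under the default ``auto`` mode the client flips to the mTLS endpoint whenever a certificate is available. Roughly, opting in from user code looks like this (a sketch; the certificate callback is a placeholder returning PEM bytes):

```python
import os

from google.api_core import client_options
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
)


def client_cert_source():
    # Placeholder: return (cert_pem_bytes, key_pem_bytes) for the client cert.
    return b"-----BEGIN CERTIFICATE-----...", b"-----BEGIN PRIVATE KEY-----..."


# Opt in to client certificates; with GOOGLE_API_USE_MTLS_ENDPOINT unset
# (i.e. "auto"), the client then switches to the default mTLS endpoint,
# as the tests in this section assert.
os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"

client = DatasetServiceClient(
    client_options=client_options.ClientOptions(
        client_cert_source=client_cert_source,
    ),
)
```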
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -284,16 +376,23 @@ def test_dataset_service_client_mtls_env_auto(client_class, transport_class, tra ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_dataset_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -306,16 +405,24 @@ def test_dataset_service_client_client_options_scopes(client_class, transport_cl client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - (DatasetServiceAsyncClient, transports.DatasetServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_dataset_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -330,10 +437,12 @@ def test_dataset_service_client_client_options_credentials_file(client_class, tr def test_dataset_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = DatasetServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -346,10 +455,11 @@ def test_dataset_service_client_client_options_from_dict(): ) -def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.CreateDatasetRequest): +def test_create_dataset( + transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -357,11 +467,9 @@ def test_create_dataset(transport: str = 'grpc', request_type=dataset_service.Cr request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_dataset(request) @@ -380,10 +488,11 @@ def test_create_dataset_from_dict(): @pytest.mark.asyncio -async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.CreateDatasetRequest): +async def test_create_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -391,12 +500,10 @@ async def test_create_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_dataset(request) @@ -417,20 +524,16 @@ async def test_create_dataset_async_from_dict(): def test_create_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_dataset(request) @@ -441,28 +544,23 @@ def test_create_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.CreateDatasetRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_dataset(request) @@ -473,29 +571,21 @@ async def test_create_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -503,47 +593,40 @@ def test_create_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") def test_create_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) @pytest.mark.asyncio async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_dataset( - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -551,31 +634,30 @@ async def test_create_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") @pytest.mark.asyncio async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_dataset( dataset_service.CreateDatasetRequest(), - parent='parent_value', - dataset=gca_dataset.Dataset(name='name_value'), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), ) -def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDatasetRequest): +def test_get_dataset( + transport: str = "grpc", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -583,19 +665,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.get_dataset(request) @@ -610,13 +686,13 @@ def test_get_dataset(transport: str = 'grpc', request_type=dataset_service.GetDa assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_dataset_from_dict(): @@ -624,10 +700,11 @@ def test_get_dataset_from_dict(): @pytest.mark.asyncio -async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetDatasetRequest): +async def test_get_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -635,16 +712,16 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.get_dataset(request) @@ -657,13 +734,13 @@ async def test_get_dataset_async(transport: str = 'grpc_asyncio', request_type=d # Establish that the response is the type that we expect. assert isinstance(response, dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -672,19 +749,15 @@ async def test_get_dataset_async_from_dict(): def test_get_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = dataset.Dataset() client.get_dataset(request) @@ -696,27 +769,20 @@ def test_get_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) await client.get_dataset(request) @@ -728,99 +794,79 @@ async def test_get_dataset_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_dataset( - name='name_value', - ) + client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset.Dataset() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_dataset( - name='name_value', - ) + response = await client.get_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_dataset( - dataset_service.GetDatasetRequest(), - name='name_value', + dataset_service.GetDatasetRequest(), name="name_value", ) -def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.UpdateDatasetRequest): +def test_update_dataset( + transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -828,19 +874,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset( - name='name_value', - - display_name='display_name_value', - - metadata_schema_uri='metadata_schema_uri_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", ) response = client.update_dataset(request) @@ -855,13 +895,13 @@ def test_update_dataset(transport: str = 'grpc', request_type=dataset_service.Up assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_dataset_from_dict(): @@ -869,10 +909,11 @@ def test_update_dataset_from_dict(): @pytest.mark.asyncio -async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.UpdateDatasetRequest): +async def test_update_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -880,16 +921,16 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset( - name='name_value', - display_name='display_name_value', - metadata_schema_uri='metadata_schema_uri_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) response = await client.update_dataset(request) @@ -902,13 +943,13 @@ async def test_update_dataset_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, gca_dataset.Dataset) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -917,19 +958,15 @@ async def test_update_dataset_async_from_dict(): def test_update_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = gca_dataset.Dataset() client.update_dataset(request) @@ -941,27 +978,22 @@ def test_update_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.UpdateDatasetRequest() - request.dataset.name = 'dataset.name/value' + request.dataset.name = "dataset.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) await client.update_dataset(request) @@ -973,29 +1005,24 @@ async def test_update_dataset_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'dataset.name=dataset.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] def test_update_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1003,36 +1030,30 @@ def test_update_dataset_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_dataset.Dataset() @@ -1040,8 +1061,8 @@ async def test_update_dataset_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.update_dataset( - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1049,31 +1070,30 @@ async def test_update_dataset_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name='name_value') + assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_dataset( dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.ListDatasetsRequest): +def test_list_datasets( + transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1081,13 +1101,10 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_datasets(request) @@ -1102,7 +1119,7 @@ def test_list_datasets(transport: str = 'grpc', request_type=dataset_service.Lis assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_datasets_from_dict(): @@ -1110,10 +1127,11 @@ def test_list_datasets_from_dict(): @pytest.mark.asyncio -async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDatasetsRequest): +async def test_list_datasets_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1121,13 +1139,13 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_datasets(request) @@ -1140,7 +1158,7 @@ async def test_list_datasets_async(transport: str = 'grpc_asyncio', request_type # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDatasetsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1149,19 +1167,15 @@ async def test_list_datasets_async_from_dict(): def test_list_datasets_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: call.return_value = dataset_service.ListDatasetsResponse() client.list_datasets(request) @@ -1173,28 +1187,23 @@ def test_list_datasets_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDatasetsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) await client.list_datasets(request) @@ -1205,138 +1214,100 @@ async def test_list_datasets_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_datasets_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDatasetsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_datasets( - parent='parent_value', - ) + client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_datasets_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = dataset_service.ListDatasetsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDatasetsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_datasets( - parent='parent_value', - ) + response = await client.list_datasets(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_datasets( - dataset_service.ListDatasetsRequest(), - parent='parent_value', + dataset_service.ListDatasetsRequest(), parent="parent_value", ) def test_list_datasets_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_datasets(request={}) @@ -1344,147 +1315,102 @@ def test_list_datasets_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in results) + assert all(isinstance(i, dataset.Dataset) for i in results) + def test_list_datasets_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_datasets), - '__call__') as call: + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) - for i in responses) + assert all(isinstance(i, dataset.Dataset) for i in responses) + @pytest.mark.asyncio async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_datasets), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - dataset.Dataset(), - ], - next_page_token='abc', - ), - dataset_service.ListDatasetsResponse( - datasets=[], - next_page_token='def', + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - ], - next_page_token='ghi', + datasets=[dataset.Dataset(),], next_page_token="ghi", ), dataset_service.ListDatasetsResponse( - datasets=[ - dataset.Dataset(), - dataset.Dataset(), - ], + datasets=[dataset.Dataset(), dataset.Dataset(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_datasets(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.DeleteDatasetRequest): +def test_delete_dataset( + transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1492,11 +1418,9 @@ def test_delete_dataset(transport: str = 'grpc', request_type=dataset_service.De request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_dataset(request) @@ -1515,10 +1439,11 @@ def test_delete_dataset_from_dict(): @pytest.mark.asyncio -async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_type=dataset_service.DeleteDatasetRequest): +async def test_delete_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1526,12 +1451,10 @@ async def test_delete_dataset_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_dataset(request) @@ -1552,20 +1475,16 @@ async def test_delete_dataset_async_from_dict(): def test_delete_dataset_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_dataset(request) @@ -1576,28 +1495,23 @@ def test_delete_dataset_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.DeleteDatasetRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_dataset(request) @@ -1608,101 +1522,81 @@ async def test_delete_dataset_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_dataset_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.delete_dataset( - name='name_value', - ) + client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_dataset_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_dataset), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_dataset( - name='name_value', - ) + response = await client.delete_dataset(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), - name='name_value', + dataset_service.DeleteDatasetRequest(), name="name_value", ) -def test_import_data(transport: str = 'grpc', request_type=dataset_service.ImportDataRequest): +def test_import_data( + transport: str = "grpc", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1710,11 +1604,9 @@ def test_import_data(transport: str = 'grpc', request_type=dataset_service.Impor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.import_data(request) @@ -1733,10 +1625,11 @@ def test_import_data_from_dict(): @pytest.mark.asyncio -async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ImportDataRequest): +async def test_import_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1744,12 +1637,10 @@ async def test_import_data_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.import_data(request) @@ -1770,20 +1661,16 @@ async def test_import_data_async_from_dict(): def test_import_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.import_data(request) @@ -1794,28 +1681,23 @@ def test_import_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ImportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.import_data(request) @@ -1826,29 +1708,24 @@ async def test_import_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_import_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -1856,47 +1733,47 @@ def test_import_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] def test_import_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) @pytest.mark.asyncio async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.import_data), - '__call__') as call: + with mock.patch.object(type(client.transport.import_data), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.import_data( - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) # Establish that the underlying call was made with the expected @@ -1904,31 +1781,34 @@ async def test_import_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].import_configs == [dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))] + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] @pytest.mark.asyncio async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.import_data( dataset_service.ImportDataRequest(), - name='name_value', - import_configs=[dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=['uris_value']))], + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], ) -def test_export_data(transport: str = 'grpc', request_type=dataset_service.ExportDataRequest): +def test_export_data( + transport: str = "grpc", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1936,11 +1816,9 @@ def test_export_data(transport: str = 'grpc', request_type=dataset_service.Expor request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.export_data(request) @@ -1959,10 +1837,11 @@ def test_export_data_from_dict(): @pytest.mark.asyncio -async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ExportDataRequest): +async def test_export_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1970,12 +1849,10 @@ async def test_export_data_async(transport: str = 'grpc_asyncio', request_type=d request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.export_data(request) @@ -1996,20 +1873,16 @@ async def test_export_data_async_from_dict(): def test_export_data_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.export_data(request) @@ -2020,28 +1893,23 @@ def test_export_data_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ExportDataRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
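
# The field-header tests reduce to one contract: any request field that is
# part of the HTTP/1.1 URI (here "name") must be echoed to the server as an
# "x-goog-request-params" routing header in the call metadata. A condensed
# sketch of that assertion (illustrative only; sketch_* is not a generated
# test name):
from unittest import mock

from google.auth import credentials
from google.longrunning import operations_pb2
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import dataset_service


def sketch_routing_header():
    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials())
    request = dataset_service.ExportDataRequest()
    request.name = "name/value"
    with mock.patch.object(type(client.transport.export_data), "__call__") as call:
        call.return_value = operations_pb2.Operation(name="operations/op")
        client.export_data(request)
        # The routing header travels as ordinary gRPC metadata.
        _, _, kw = call.mock_calls[0]
        assert ("x-goog-request-params", "name=name/value") in kw["metadata"]
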
- with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.export_data(request) @@ -2052,29 +1920,26 @@ async def test_export_data_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_export_data_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2082,47 +1947,53 @@ def test_export_data_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) def test_export_data_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) @pytest.mark.asyncio async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.export_data), - '__call__') as call: + with mock.patch.object(type(client.transport.export_data), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.export_data( - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) # Establish that the underlying call was made with the expected @@ -2130,31 +2001,38 @@ async def test_export_data_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')) + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) @pytest.mark.asyncio async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.export_data( dataset_service.ExportDataRequest(), - name='name_value', - export_config=dataset.ExportDataConfig(gcs_destination=io.GcsDestination(output_uri_prefix='output_uri_prefix_value')), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), ) -def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.ListDataItemsRequest): +def test_list_data_items( + transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2162,13 +2040,10 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_items(request) @@ -2183,7 +2058,7 @@ def test_list_data_items(transport: str = 'grpc', request_type=dataset_service.L assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_items_from_dict(): @@ -2191,10 +2066,11 @@ def test_list_data_items_from_dict(): @pytest.mark.asyncio -async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListDataItemsRequest): +async def test_list_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2202,13 +2078,13 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_items(request) @@ -2221,7 +2097,7 @@ async def test_list_data_items_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataItemsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2230,19 +2106,15 @@ async def test_list_data_items_async_from_dict(): def test_list_data_items_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: call.return_value = dataset_service.ListDataItemsResponse() client.list_data_items(request) @@ -2254,28 +2126,23 @@ def test_list_data_items_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListDataItemsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) await client.list_data_items(request) @@ -2286,104 +2153,81 @@ async def test_list_data_items_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_items_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListDataItemsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_items( - parent='parent_value', - ) + client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_items_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Designate an appropriate return value for the call. 
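
# The flattened-call tests pin down one rule of the GAPIC surface: a method
# accepts either a fully-formed request object or flattened keyword
# arguments, never both, and mixing them raises ValueError before any RPC is
# attempted. A condensed sketch (illustrative only):
import pytest

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import dataset_service


def sketch_flattened_exclusivity():
    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials())
    with pytest.raises(ValueError):
        client.list_data_items(
            dataset_service.ListDataItemsRequest(), parent="parent_value",
        )
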
call.return_value = dataset_service.ListDataItemsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListDataItemsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_items( - parent='parent_value', - ) + response = await client.list_data_items(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_data_items( - dataset_service.ListDataItemsRequest(), - parent='parent_value', + dataset_service.ListDataItemsRequest(), parent="parent_value", ) def test_list_data_items_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2392,32 +2236,23 @@ def test_list_data_items_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_items(request={}) @@ -2425,18 +2260,14 @@ def test_list_data_items_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in results) + assert all(isinstance(i, data_item.DataItem) for i in results) + def test_list_data_items_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_data_items), - '__call__') as call: + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2445,40 +2276,32 @@ def test_list_data_items_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2487,46 +2310,37 @@ async def test_list_data_items_async_pager(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) - for i in responses) + assert all(isinstance(i, data_item.DataItem) for i in responses) + @pytest.mark.asyncio async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_items), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
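
# The pager tests all follow one recipe: queue one canned response per page
# on call.side_effect, let the pager stitch the pages together, and leave a
# trailing RuntimeError to fail loudly if the pager ever asks for a page past
# the final empty-token response. A condensed two-page sketch (illustrative
# only):
from unittest import mock

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import data_item, dataset_service


def sketch_pager():
    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials())
    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
        call.side_effect = (
            dataset_service.ListDataItemsResponse(
                data_items=[data_item.DataItem()], next_page_token="abc",
            ),
            dataset_service.ListDataItemsResponse(data_items=[data_item.DataItem()]),
            RuntimeError,
        )
        # Iterating the pager transparently fetches page two while the token
        # is non-empty, then stops on the empty token.
        results = [i for i in client.list_data_items(request={})]
        assert len(results) == 2
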
call.side_effect = ( dataset_service.ListDataItemsResponse( @@ -2535,37 +2349,31 @@ async def test_list_data_items_async_pages(): data_item.DataItem(), data_item.DataItem(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListDataItemsResponse( - data_items=[], - next_page_token='def', + data_items=[], next_page_token="def", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - ], - next_page_token='ghi', + data_items=[data_item.DataItem(),], next_page_token="ghi", ), dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - ], + data_items=[data_item.DataItem(), data_item.DataItem(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_data_items(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_service.GetAnnotationSpecRequest): +def test_get_annotation_spec( + transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2574,16 +2382,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec( - name='name_value', - - display_name='display_name_value', - - etag='etag_value', - + name="name_value", display_name="display_name_value", etag="etag_value", ) response = client.get_annotation_spec(request) @@ -2598,11 +2401,11 @@ def test_get_annotation_spec(transport: str = 'grpc', request_type=dataset_servi assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_annotation_spec_from_dict(): @@ -2610,10 +2413,12 @@ def test_get_annotation_spec_from_dict(): @pytest.mark.asyncio -async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', request_type=dataset_service.GetAnnotationSpecRequest): +async def test_get_annotation_spec_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetAnnotationSpecRequest, +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2622,14 +2427,14 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec( - name='name_value', - display_name='display_name_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + ) response = await client.get_annotation_spec(request) @@ -2642,11 +2447,11 @@ async def test_get_annotation_spec_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -2655,19 +2460,17 @@ async def test_get_annotation_spec_async_from_dict(): def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: call.return_value = annotation_spec.AnnotationSpec() client.get_annotation_spec(request) @@ -2679,28 +2482,25 @@ def test_get_annotation_spec_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.GetAnnotationSpecRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) await client.get_annotation_spec(request) @@ -2711,99 +2511,85 @@ async def test_get_annotation_spec_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_annotation_spec_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_annotation_spec( - name='name_value', - ) + client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_annotation_spec), - '__call__') as call: + type(client.transport.get_annotation_spec), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = annotation_spec.AnnotationSpec() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(annotation_spec.AnnotationSpec()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_annotation_spec( - name='name_value', - ) + response = await client.get_annotation_spec(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), - name='name_value', + dataset_service.GetAnnotationSpecRequest(), name="name_value", ) -def test_list_annotations(transport: str = 'grpc', request_type=dataset_service.ListAnnotationsRequest): +def test_list_annotations( + transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2811,13 +2597,10 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_annotations(request) @@ -2832,7 +2615,7 @@ def test_list_annotations(transport: str = 'grpc', request_type=dataset_service. assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_annotations_from_dict(): @@ -2840,10 +2623,11 @@ def test_list_annotations_from_dict(): @pytest.mark.asyncio -async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_type=dataset_service.ListAnnotationsRequest): +async def test_list_annotations_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2851,13 +2635,13 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_annotations(request) @@ -2870,7 +2654,7 @@ async def test_list_annotations_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListAnnotationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2879,19 +2663,15 @@ async def test_list_annotations_async_from_dict(): def test_list_annotations_field_headers(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: call.return_value = dataset_service.ListAnnotationsResponse() client.list_annotations(request) @@ -2903,28 +2683,23 @@ def test_list_annotations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = dataset_service.ListAnnotationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) await client.list_annotations(request) @@ -2935,104 +2710,81 @@ async def test_list_annotations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_annotations_flattened(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_annotations( - parent='parent_value', - ) + client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_annotations_flattened_error(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = dataset_service.ListAnnotationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset_service.ListAnnotationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_annotations( - parent='parent_value', - ) + response = await client.list_annotations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_annotations( - dataset_service.ListAnnotationsRequest(), - parent='parent_value', + dataset_service.ListAnnotationsRequest(), parent="parent_value", ) def test_list_annotations_pager(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3041,32 +2793,23 @@ def test_list_annotations_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_annotations(request={}) @@ -3074,18 +2817,14 @@ def test_list_annotations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in results) + assert all(isinstance(i, annotation.Annotation) for i in results) + def test_list_annotations_pages(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_annotations), - '__call__') as call: + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3094,40 +2833,32 @@ def test_list_annotations_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
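
# The async pager variants add one wrinkle: "__call__" must be patched with
# new_callable=mock.AsyncMock (Python 3.8+, or the mock backport) so the
# transport method can be awaited, and iteration uses "async for". A
# condensed sketch (illustrative only):
from unittest import mock

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceAsyncClient,
)
from google.cloud.aiplatform_v1beta1.types import annotation, dataset_service


async def sketch_async_pager():
    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials())
    with mock.patch.object(
        type(client.transport.list_annotations),
        "__call__",
        new_callable=mock.AsyncMock,
    ) as call:
        call.side_effect = (
            dataset_service.ListAnnotationsResponse(
                annotations=[annotation.Annotation()], next_page_token="abc",
            ),
            dataset_service.ListAnnotationsResponse(
                annotations=[annotation.Annotation()],
            ),
            RuntimeError,
        )
        async_pager = await client.list_annotations(request={})
        responses = [r async for r in async_pager]
        assert len(responses) == 2
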
call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3136,46 +2867,37 @@ async def test_list_annotations_async_pager(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) - for i in responses) + assert all(isinstance(i, annotation.Annotation) for i in responses) + @pytest.mark.asyncio async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_annotations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( dataset_service.ListAnnotationsResponse( @@ -3184,30 +2906,23 @@ async def test_list_annotations_async_pages(): annotation.Annotation(), annotation.Annotation(), ], - next_page_token='abc', + next_page_token="abc", ), dataset_service.ListAnnotationsResponse( - annotations=[], - next_page_token='def', + annotations=[], next_page_token="def", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - ], - next_page_token='ghi', + annotations=[annotation.Annotation(),], next_page_token="ghi", ), dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - ], + annotations=[annotation.Annotation(), annotation.Annotation(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_annotations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3218,8 +2933,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. 
@@ -3238,8 +2952,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3267,13 +2980,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3281,13 +2997,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.DatasetServiceGrpcTransport, - ) + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) def test_dataset_service_base_transport_error(): @@ -3295,13 +3006,15 @@ def test_dataset_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_dataset_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.DatasetServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3310,17 +3023,17 @@ def test_dataset_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
methods = ( - 'create_dataset', - 'get_dataset', - 'update_dataset', - 'list_datasets', - 'delete_dataset', - 'import_data', - 'export_data', - 'list_data_items', - 'get_annotation_spec', - 'list_annotations', - ) + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3333,23 +3046,28 @@ def test_dataset_service_base_transport(): def test_dataset_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_dataset_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.DatasetServiceTransport() @@ -3358,11 +3076,11 @@ def test_dataset_service_base_transport_with_adc(): def test_dataset_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) DatasetServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3370,37 +3088,43 @@ def test_dataset_service_auth_adc(): def test_dataset_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. 
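
# Every ADC test above shares one shape: patch google.auth's "default",
# construct the client or transport without explicit credentials, and assert
# that ADC was consulted exactly once with the cloud-platform scope. A
# condensed sketch (illustrative only):
from unittest import mock

from google import auth
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
)


def sketch_adc_lookup():
    with mock.patch.object(auth, "default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
        DatasetServiceClient()
        adc.assert_called_once_with(
            scopes=("https://www.googleapis.com/auth/cloud-platform",),
            quota_project_id=None,
        )
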
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.DatasetServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_dataset_service_host_with_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3408,24 +3132,33 @@ def test_dataset_service_grpc_transport_channel(): def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
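    # A caller-supplied channel is adopted verbatim: the transport creates no
    # channel of its own and, as the identity assertion below implies, attaches
    # no ssl channel credentials to it.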
transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3434,7 +3167,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3450,9 +3183,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3460,17 +3191,23 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize("transport_class", [transports.DatasetServiceGrpcTransport, transports.DatasetServiceGrpcAsyncIOTransport]) -def test_dataset_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3487,9 +3224,7 @@ def test_dataset_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3498,16 +3233,12 @@ def test_dataset_service_transport_channel_mtls_with_adc( def test_dataset_service_grpc_lro_client(): client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - 
transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -3515,20 +3246,17 @@ def test_dataset_service_grpc_lro_client(): def test_dataset_service_grpc_lro_async_client(): client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_annotation_path(): project = "squid" location = "clam" @@ -3536,19 +3264,26 @@ def test_annotation_path(): data_item = "octopus" annotation = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(project=project, location=location, dataset=dataset, data_item=data_item, annotation=annotation, ) - actual = DatasetServiceClient.annotation_path(project, location, dataset, data_item, annotation) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) + actual = DatasetServiceClient.annotation_path( + project, location, dataset, data_item, annotation + ) assert expected == actual def test_parse_annotation_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - "data_item": "winkle", - "annotation": "nautilus", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + "data_item": "winkle", + "annotation": "nautilus", } path = DatasetServiceClient.annotation_path(**expected) @@ -3556,24 +3291,31 @@ def test_parse_annotation_path(): actual = DatasetServiceClient.parse_annotation_path(path) assert expected == actual + def test_annotation_spec_path(): project = "scallop" location = "abalone" dataset = "squid" annotation_spec = "clam" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(project=project, location=location, dataset=dataset, annotation_spec=annotation_spec, ) - actual = DatasetServiceClient.annotation_spec_path(project, location, dataset, annotation_spec) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + actual = DatasetServiceClient.annotation_spec_path( + project, location, dataset, annotation_spec + ) assert expected == actual def test_parse_annotation_spec_path(): expected = { - "project": "whelk", - "location": "octopus", - "dataset": "oyster", - "annotation_spec": "nudibranch", - + "project": "whelk", + "location": "octopus", + "dataset": 
"oyster", + "annotation_spec": "nudibranch", } path = DatasetServiceClient.annotation_spec_path(**expected) @@ -3581,24 +3323,26 @@ def test_parse_annotation_spec_path(): actual = DatasetServiceClient.parse_annotation_spec_path(path) assert expected == actual + def test_data_item_path(): project = "cuttlefish" location = "mussel" dataset = "winkle" data_item = "nautilus" - expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format(project=project, location=location, dataset=dataset, data_item=data_item, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) assert expected == actual def test_parse_data_item_path(): expected = { - "project": "scallop", - "location": "abalone", - "dataset": "squid", - "data_item": "clam", - + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", } path = DatasetServiceClient.data_item_path(**expected) @@ -3606,22 +3350,24 @@ def test_parse_data_item_path(): actual = DatasetServiceClient.parse_data_item_path(path) assert expected == actual + def test_dataset_path(): project = "whelk" location = "octopus" dataset = "oyster" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, ) + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) actual = DatasetServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nudibranch", - "location": "cuttlefish", - "dataset": "mussel", - + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", } path = DatasetServiceClient.dataset_path(**expected) @@ -3629,18 +3375,20 @@ def test_parse_dataset_path(): actual = DatasetServiceClient.parse_dataset_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "winkle" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = DatasetServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "nautilus", - + "billing_account": "nautilus", } path = DatasetServiceClient.common_billing_account_path(**expected) @@ -3648,18 +3396,18 @@ def test_parse_common_billing_account_path(): actual = DatasetServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "scallop" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = DatasetServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "abalone", - + "folder": "abalone", } path = DatasetServiceClient.common_folder_path(**expected) @@ -3667,18 +3415,18 @@ def test_parse_common_folder_path(): actual = DatasetServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "squid" - expected = "organizations/{organization}".format(organization=organization, ) + 
expected = "organizations/{organization}".format(organization=organization,) actual = DatasetServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "clam", - + "organization": "clam", } path = DatasetServiceClient.common_organization_path(**expected) @@ -3686,18 +3434,18 @@ def test_parse_common_organization_path(): actual = DatasetServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "whelk" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = DatasetServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "octopus", - + "project": "octopus", } path = DatasetServiceClient.common_project_path(**expected) @@ -3705,20 +3453,22 @@ def test_parse_common_project_path(): actual = DatasetServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "oyster" location = "nudibranch" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = DatasetServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "cuttlefish", - "location": "mussel", - + "project": "cuttlefish", + "location": "mussel", } path = DatasetServiceClient.common_location_path(**expected) @@ -3730,17 +3480,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.DatasetServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = DatasetServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 29daf6cff8..93c35a7a2a 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.endpoint_service import EndpointServiceClient +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceAsyncClient, +) +from 
google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceClient, +) from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type @@ -62,7 +66,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -73,17 +81,35 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) -@pytest.mark.parametrize("client_class", [EndpointServiceClient, EndpointServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] +) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -91,7 +117,7 @@ def test_endpoint_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_client_get_transport_class(): @@ -102,29 +128,44 @@ def test_endpoint_service_client_get_transport_class(): assert transport == transports.EndpointServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) 
-@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) -def test_endpoint_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -140,7 +181,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -156,7 +197,7 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -176,13 +217,15 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
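    # GOOGLE_API_USE_CLIENT_CERTIFICATE only accepts the strings "true" and
    # "false"; any other value should make client construction fail fast with
    # a ValueError.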
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -195,26 +238,66 @@ def test_endpoint_service_client_client_options(client_class, transport_class, t client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "true"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc", "false"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(EndpointServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceClient)) -@mock.patch.object(EndpointServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(EndpointServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -237,11 +320,21 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -251,7 +344,9 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -266,10 +361,17 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) # Check the case client_cert_source and ADC client cert are not provided. 
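    # With neither an explicit client_cert_source nor an mTLS certificate
    # available from ADC, the client is expected to stay on the plain
    # DEFAULT_ENDPOINT with no ssl channel credentials, whatever the env
    # variable says.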
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -284,16 +386,23 @@ def test_endpoint_service_client_mtls_env_auto(client_class, transport_class, tr ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_endpoint_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -306,16 +415,24 @@ def test_endpoint_service_client_client_options_scopes(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - (EndpointServiceAsyncClient, transports.EndpointServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_endpoint_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -330,10 +447,12 @@ def test_endpoint_service_client_client_options_credentials_file(client_class, t def test_endpoint_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = EndpointServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -346,10 +465,11 @@ def test_endpoint_service_client_client_options_from_dict(): ) -def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service.CreateEndpointRequest): +def test_create_endpoint( + transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -357,11 +477,9 @@ def test_create_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_endpoint(request) @@ -380,10 +498,11 @@ def test_create_endpoint_from_dict(): @pytest.mark.asyncio -async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.CreateEndpointRequest): +async def test_create_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -391,12 +510,10 @@ async def test_create_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
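        # grpc_helpers_async.FakeUnaryUnaryCall wraps the canned response in an
        # awaitable stand-in for a unary-unary grpc.aio call, so awaiting
        # client.create_endpoint(...) resolves to the Operation designated
        # below.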
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_endpoint(request) @@ -417,20 +534,16 @@ async def test_create_endpoint_async_from_dict(): def test_create_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_endpoint(request) @@ -441,28 +554,23 @@ def test_create_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.CreateEndpointRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_endpoint(request) @@ -473,29 +581,21 @@ async def test_create_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
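        # "Flattened" here means the request fields are passed as plain keyword
        # arguments and the client assembles the CreateEndpointRequest itself;
        # mixing this style with an explicit request object is rejected with a
        # ValueError, as the *_flattened_error tests below show.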
client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -503,47 +603,40 @@ def test_create_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") def test_create_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) @pytest.mark.asyncio async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.create_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_endpoint( - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -551,31 +644,30 @@ async def test_create_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") @pytest.mark.asyncio async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_endpoint( endpoint_service.CreateEndpointRequest(), - parent='parent_value', - endpoint=gca_endpoint.Endpoint(name='name_value'), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), ) -def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.GetEndpointRequest): +def test_get_endpoint( + transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -583,19 +675,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.get_endpoint(request) @@ -610,13 +696,13 @@ def test_get_endpoint(transport: str = 'grpc', request_type=endpoint_service.Get assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_endpoint_from_dict(): @@ -624,10 +710,11 @@ def test_get_endpoint_from_dict(): @pytest.mark.asyncio -async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.GetEndpointRequest): +async def test_get_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -635,16 +722,16 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.get_endpoint(request) @@ -657,13 +744,13 @@ async def test_get_endpoint_async(transport: str = 'grpc_asyncio', request_type= # Establish that the response is the type that we expect. assert isinstance(response, endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -672,19 +759,15 @@ async def test_get_endpoint_async_from_dict(): def test_get_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = endpoint.Endpoint() client.get_endpoint(request) @@ -696,27 +779,20 @@ def test_get_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.GetEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) await client.get_endpoint(request) @@ -728,99 +804,79 @@ async def test_get_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_endpoint( - name='name_value', - ) + client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint.Endpoint() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_endpoint( - name='name_value', - ) + response = await client.get_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_endpoint( - endpoint_service.GetEndpointRequest(), - name='name_value', + endpoint_service.GetEndpointRequest(), name="name_value", ) -def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.ListEndpointsRequest): +def test_list_endpoints( + transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -828,13 +884,10 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
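    # client.transport.list_endpoints is a gRPC multicallable; patching
    # __call__ on its type intercepts the RPC at the channel boundary, so none
    # of these tests perform any network I/O.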
- with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_endpoints(request) @@ -849,7 +902,7 @@ def test_list_endpoints(transport: str = 'grpc', request_type=endpoint_service.L assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_endpoints_from_dict(): @@ -857,10 +910,11 @@ def test_list_endpoints_from_dict(): @pytest.mark.asyncio -async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.ListEndpointsRequest): +async def test_list_endpoints_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -868,13 +922,13 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_endpoints(request) @@ -887,7 +941,7 @@ async def test_list_endpoints_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListEndpointsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -896,19 +950,15 @@ async def test_list_endpoints_async_from_dict(): def test_list_endpoints_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: call.return_value = endpoint_service.ListEndpointsResponse() client.list_endpoints(request) @@ -920,28 +970,23 @@ def test_list_endpoints_field_headers(): # Establish that the field header was sent. 
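    # Request fields that form part of the resource URI (here `parent`) are
    # mirrored into the x-goog-request-params metadata entry so the server can
    # route the call; the assertion below pulls that header out of the mocked
    # call's kwargs.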
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.ListEndpointsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) await client.list_endpoints(request) @@ -952,104 +997,81 @@ async def test_list_endpoints_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_endpoints_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = endpoint_service.ListEndpointsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_endpoints( - parent='parent_value', - ) + client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_endpoints_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = endpoint_service.ListEndpointsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint_service.ListEndpointsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_endpoints( - parent='parent_value', - ) + response = await client.list_endpoints(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), - parent='parent_value', + endpoint_service.ListEndpointsRequest(), parent="parent_value", ) def test_list_endpoints_pager(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1058,32 +1080,23 @@ def test_list_endpoints_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_endpoints(request={}) @@ -1091,18 +1104,14 @@ def test_list_endpoints_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in results) + assert all(isinstance(i, endpoint.Endpoint) for i in results) + def test_list_endpoints_pages(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_endpoints), - '__call__') as call: + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: # Set the response to a series of pages. 
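The flattened_error tests above pin down a GAPIC calling convention: a method accepts either a fully formed request object or flattened keyword fields, never both, and the check fires client-side before any RPC is attempted. A self-contained sketch of the same contract (the test name is invented):

import pytest
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import endpoint_service


def test_request_and_flattened_fields_are_exclusive():
    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials())
    # Passing a request object and a flattened field together is rejected
    # locally; no transport call is ever recorded.
    with pytest.raises(ValueError):
        client.list_endpoints(
            endpoint_service.ListEndpointsRequest(), parent="parent_value",
        )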
call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1111,40 +1120,32 @@ def test_list_endpoints_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1153,46 +1154,37 @@ async def test_list_endpoints_async_pager(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) async_pager = await client.list_endpoints(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, endpoint.Endpoint) - for i in responses) + assert all(isinstance(i, endpoint.Endpoint) for i in responses) + @pytest.mark.asyncio async def test_list_endpoints_async_pages(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_endpoints), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
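Why the page fixtures end with a bare RuntimeError: it is a sentinel that fails loudly if the pager fetches one page more than the fixtures describe, since iteration must stop as soon as a response carries an empty next_page_token. A minimal runnable version of that contract (test name invented):

from unittest import mock

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import endpoint, endpoint_service


def test_pager_stops_on_empty_token():
    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials())
    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
        call.side_effect = (
            endpoint_service.ListEndpointsResponse(
                endpoints=[endpoint.Endpoint()], next_page_token="abc"
            ),
            # No next_page_token here, so iteration ends with this page.
            endpoint_service.ListEndpointsResponse(endpoints=[endpoint.Endpoint()]),
            RuntimeError,  # sentinel: reached only if the pager over-fetches
        )
        results = list(client.list_endpoints(request={}))

    assert len(results) == 2  # items from both pages, flattened
    assert call.call_count == 2  # the sentinel was never consumed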
call.side_effect = ( endpoint_service.ListEndpointsResponse( @@ -1201,37 +1193,31 @@ async def test_list_endpoints_async_pages(): endpoint.Endpoint(), endpoint.Endpoint(), ], - next_page_token='abc', + next_page_token="abc", ), endpoint_service.ListEndpointsResponse( - endpoints=[], - next_page_token='def', + endpoints=[], next_page_token="def", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - ], - next_page_token='ghi', + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", ), endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - ], + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_endpoints(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service.UpdateEndpointRequest): +def test_update_endpoint( + transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1239,19 +1225,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint( - name='name_value', - - display_name='display_name_value', - - description='description_value', - - etag='etag_value', - + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", ) response = client.update_endpoint(request) @@ -1266,13 +1246,13 @@ def test_update_endpoint(transport: str = 'grpc', request_type=endpoint_service. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_update_endpoint_from_dict(): @@ -1280,10 +1260,11 @@ def test_update_endpoint_from_dict(): @pytest.mark.asyncio -async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UpdateEndpointRequest): +async def test_update_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1291,16 +1272,16 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint( - name='name_value', - display_name='display_name_value', - description='description_value', - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) response = await client.update_endpoint(request) @@ -1313,13 +1294,13 @@ async def test_update_endpoint_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -1328,19 +1309,15 @@ async def test_update_endpoint_async_from_dict(): def test_update_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: call.return_value = gca_endpoint.Endpoint() client.update_endpoint(request) @@ -1352,28 +1329,25 @@ def test_update_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] @pytest.mark.asyncio async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = 'endpoint.name/value' + request.endpoint.name = "endpoint.name/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) await client.update_endpoint(request) @@ -1384,29 +1358,24 @@ async def test_update_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint.name=endpoint.name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] def test_update_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = gca_endpoint.Endpoint() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1414,45 +1383,41 @@ def test_update_endpoint_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @pytest.mark.asyncio async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.update_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
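In the update_endpoint tests above, the flattened call always pairs the Endpoint message with a FieldMask: request.endpoint carries the candidate values, and request.update_mask lists the paths the server may actually write. A short illustration; the field values here are invented, not taken from the patch:

from google.protobuf import field_mask_pb2 as field_mask

from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint

# Rename an endpoint without touching its other fields: only the paths named
# in the mask are applied server-side.
endpoint_msg = gca_endpoint.Endpoint(name="name_value", display_name="renamed")
mask = field_mask.FieldMask(paths=["display_name"])
assert list(mask.paths) == ["display_name"]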
call.return_value = gca_endpoint.Endpoint() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_endpoint.Endpoint()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1460,31 +1425,30 @@ async def test_update_endpoint_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name='name_value') + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.update_endpoint( endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) -def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service.DeleteEndpointRequest): +def test_delete_endpoint( + transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1492,11 +1456,9 @@ def test_delete_endpoint(transport: str = 'grpc', request_type=endpoint_service. request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_endpoint(request) @@ -1515,10 +1477,11 @@ def test_delete_endpoint_from_dict(): @pytest.mark.asyncio -async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeleteEndpointRequest): +async def test_delete_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1526,12 +1489,10 @@ async def test_delete_endpoint_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_endpoint(request) @@ -1552,20 +1513,16 @@ async def test_delete_endpoint_async_from_dict(): def test_delete_endpoint_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_endpoint(request) @@ -1576,28 +1533,23 @@ def test_delete_endpoint_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeleteEndpointRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
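delete_endpoint (like deploy_model and undeploy_model below) is a long-running method: the stub returns a raw google.longrunning Operation proto, and the client hands back a polling future rather than the proto itself. The patch elides the unchanged assertion lines, so the sketch below reconstructs the shape assuming the standard GAPIC LRO wrapping; the test name is invented.

from unittest import mock

from google.api_core import future
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)
from google.cloud.aiplatform_v1beta1.types import endpoint_service
from google.longrunning import operations_pb2


def test_lro_wrapping_sketch():
    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials())
    with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call:
        # The stub yields the raw Operation; the client wraps it so callers
        # can poll the operations service for the terminal response.
        call.return_value = operations_pb2.Operation(name="operations/spam")
        response = client.delete_endpoint(endpoint_service.DeleteEndpointRequest())

    assert isinstance(response, future.Future)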
- with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_endpoint(request) @@ -1608,101 +1560,81 @@ async def test_delete_endpoint_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_endpoint_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_endpoint( - name='name_value', - ) + client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.delete_endpoint), - '__call__') as call: + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_endpoint( - name='name_value', - ) + response = await client.delete_endpoint(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), - name='name_value', + endpoint_service.DeleteEndpointRequest(), name="name_value", ) -def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.DeployModelRequest): +def test_deploy_model( + transport: str = "grpc", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1710,11 +1642,9 @@ def test_deploy_model(transport: str = 'grpc', request_type=endpoint_service.Dep request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.deploy_model(request) @@ -1733,10 +1663,11 @@ def test_deploy_model_from_dict(): @pytest.mark.asyncio -async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.DeployModelRequest): +async def test_deploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1744,12 +1675,10 @@ async def test_deploy_model_async(transport: str = 'grpc_asyncio', request_type= request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.deploy_model(request) @@ -1770,20 +1699,16 @@ async def test_deploy_model_async_from_dict(): def test_deploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.deploy_model(request) @@ -1794,28 +1719,23 @@ def test_deploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.DeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.deploy_model(request) @@ -1826,30 +1746,29 @@ async def test_deploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_deploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -1857,51 +1776,63 @@ def test_deploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_deploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.deploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
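The nested DeployedModel literal in the deploy_model calls above is hard to scan once Black splits it over seven lines. As a side observation rather than anything this patch changes: proto-plus messages also accept plain dicts for message-typed fields, so an equivalent spelling is possible:

from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint

# Dict form of the DeployedModel literal used in the flattened tests; both
# spellings should produce equal messages, so either satisfies the asserts.
deployed_model = gca_endpoint.DeployedModel(
    dedicated_resources={"machine_spec": {"machine_type": "machine_type_value"}}
)
assert (
    deployed_model.dedicated_resources.machine_spec.machine_type
    == "machine_type_value"
)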
response = await client.deploy_model( - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -1909,34 +1840,45 @@ async def test_deploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))) + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.deploy_model( endpoint_service.DeployModelRequest(), - endpoint='endpoint_value', - deployed_model=gca_endpoint.DeployedModel(dedicated_resources=machine_resources.DedicatedResources(machine_spec=machine_resources.MachineSpec(machine_type='machine_type_value'))), - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, ) -def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.UndeployModelRequest): +def test_undeploy_model( + transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1944,11 +1886,9 @@ def test_undeploy_model(transport: str = 'grpc', request_type=endpoint_service.U request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.undeploy_model(request) @@ -1967,10 +1907,11 @@ def test_undeploy_model_from_dict(): @pytest.mark.asyncio -async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_type=endpoint_service.UndeployModelRequest): +async def test_undeploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest +): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1978,12 +1919,10 @@ async def test_undeploy_model_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.undeploy_model(request) @@ -2004,20 +1943,16 @@ async def test_undeploy_model_async_from_dict(): def test_undeploy_model_field_headers(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.undeploy_model(request) @@ -2028,28 +1963,23 @@ def test_undeploy_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] @pytest.mark.asyncio async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = endpoint_service.UndeployModelRequest() - request.endpoint = 'endpoint/value' + request.endpoint = "endpoint/value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.undeploy_model(request) @@ -2060,30 +1990,23 @@ async def test_undeploy_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'endpoint=endpoint/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] def test_undeploy_model_flattened(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2091,51 +2014,45 @@ def test_undeploy_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} def test_undeploy_model_flattened_error(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @pytest.mark.asyncio async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.undeploy_model), - '__call__') as call: + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.undeploy_model( - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) # Establish that the underlying call was made with the expected @@ -2143,27 +2060,25 @@ async def test_undeploy_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].endpoint == 'endpoint_value' + assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == 'deployed_model_id_value' + assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {'key_value': 541} + assert args[0].traffic_split == {"key_value": 541} @pytest.mark.asyncio async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.undeploy_model( endpoint_service.UndeployModelRequest(), - endpoint='endpoint_value', - deployed_model_id='deployed_model_id_value', - traffic_split={'key_value': 541}, + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, ) @@ -2174,8 +2089,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -2194,8 +2108,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -2223,13 +2136,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -2237,13 +2153,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
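test_transport_adc above patches google.auth.default so that no real Application Default Credentials lookup happens on CI machines, while the transport under test still runs its normal credential-resolution path. The same pattern, condensed and self-contained (test name invented):

from unittest import mock

from google import auth
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports


def test_adc_patch_sketch():
    with mock.patch.object(auth, "default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
        # No credentials passed, so the transport must fall back to ADC.
        transports.EndpointServiceGrpcTransport()
        adc.assert_called_once()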
- client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.EndpointServiceGrpcTransport, - ) + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) def test_endpoint_service_base_transport_error(): @@ -2251,13 +2162,15 @@ def test_endpoint_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_endpoint_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.EndpointServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -2266,14 +2179,14 @@ def test_endpoint_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_endpoint', - 'get_endpoint', - 'list_endpoints', - 'update_endpoint', - 'delete_endpoint', - 'deploy_model', - 'undeploy_model', - ) + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -2286,23 +2199,28 @@ def test_endpoint_service_base_transport(): def test_endpoint_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_endpoint_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
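The base-transport tests above rely on EndpointServiceTransport being an abstract base whose RPC surfaces exist only to raise, which forces the concrete gRPC and gRPC-asyncio transports to override every one of them. A sketch of that contract, patching _prep_wrapped_messages the same way the credentials-file test does, since the base __init__ would otherwise touch the not-implemented methods while wrapping them; the test name is invented:

from unittest import mock

import pytest
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports


def test_base_transport_methods_raise_sketch():
    with mock.patch.object(
        transports.EndpointServiceTransport, "_prep_wrapped_messages"
    ):
        transport = transports.EndpointServiceTransport(
            credentials=credentials.AnonymousCredentials(),
        )
    # Every RPC attribute on the base class raises until overridden.
    with pytest.raises(NotImplementedError):
        transport.list_endpoints(request=object())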
- with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.EndpointServiceTransport() @@ -2311,11 +2229,11 @@ def test_endpoint_service_base_transport_with_adc(): def test_endpoint_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) EndpointServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -2323,37 +2241,43 @@ def test_endpoint_service_auth_adc(): def test_endpoint_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.EndpointServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_endpoint_service_host_with_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -2361,24 +2285,33 @@ def test_endpoint_service_grpc_transport_channel(): def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
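The channel tests here verify injection: when a caller supplies a ready-made gRPC channel, the transport adopts it as-is, normalizes the host by appending the default ":443", and records no SSL channel credentials of its own. A condensed sketch of the sync case:

import grpc

from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports

channel = grpc.insecure_channel("http://localhost/")
transport = transports.EndpointServiceGrpcTransport(
    host="squid.clam.whelk", channel=channel,
)
assert transport.grpc_channel is channel  # the injected channel is adopted
assert transport._host == "squid.clam.whelk:443"
assert transport._ssl_channel_credentials is None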
transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2387,7 +2320,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2403,9 +2336,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -2413,17 +2344,23 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize("transport_class", [transports.EndpointServiceGrpcTransport, transports.EndpointServiceGrpcAsyncIOTransport]) -def test_endpoint_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2440,9 +2377,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -2451,16 +2386,12 @@ def test_endpoint_service_transport_channel_mtls_with_adc( def test_endpoint_service_grpc_lro_client(): client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), 
- transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -2468,36 +2399,34 @@ def test_endpoint_service_grpc_lro_client(): def test_endpoint_service_grpc_lro_async_client(): client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = EndpointServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = EndpointServiceClient.endpoint_path(**expected) @@ -2505,22 +2434,24 @@ def test_parse_endpoint_path(): actual = EndpointServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = EndpointServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = EndpointServiceClient.model_path(**expected) @@ -2528,18 +2459,20 @@ def test_parse_model_path(): actual = EndpointServiceClient.parse_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "squid" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = EndpointServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", - + "billing_account": "clam", } path = EndpointServiceClient.common_billing_account_path(**expected) @@ -2547,18 +2480,18 @@ def test_parse_common_billing_account_path(): actual = 
EndpointServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "whelk" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = EndpointServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "octopus", - + "folder": "octopus", } path = EndpointServiceClient.common_folder_path(**expected) @@ -2566,18 +2499,18 @@ def test_parse_common_folder_path(): actual = EndpointServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "oyster" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = EndpointServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", - + "organization": "nudibranch", } path = EndpointServiceClient.common_organization_path(**expected) @@ -2585,18 +2518,18 @@ def test_parse_common_organization_path(): actual = EndpointServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "cuttlefish" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = EndpointServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "mussel", - + "project": "mussel", } path = EndpointServiceClient.common_project_path(**expected) @@ -2604,20 +2537,22 @@ def test_parse_common_project_path(): actual = EndpointServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "winkle" location = "nautilus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = EndpointServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", - + "project": "scallop", + "location": "abalone", } path = EndpointServiceClient.common_location_path(**expected) @@ -2629,17 +2564,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.EndpointServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = EndpointServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff 
--git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index a5543f7767..f08d84bd2f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -41,14 +41,22 @@ from google.cloud.aiplatform_v1beta1.services.job_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job as gca_batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) from google.cloud.aiplatform_v1beta1.types import completion_stats from google.cloud.aiplatform_v1beta1.types import custom_job from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job as gca_hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_service from google.cloud.aiplatform_v1beta1.types import job_state @@ -75,7 +83,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -86,17 +98,30 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi @pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -104,7 +129,7 @@ def test_job_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_job_service_client_get_transport_class(): @@ -115,29 +140,42 @@ def test_job_service_client_get_transport_class(): assert transport == transports.JobServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) -def test_job_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new 
one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -153,7 +191,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -169,7 +207,7 @@ def test_job_service_client_client_options(client_class, transport_class, transp # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -189,13 +227,15 @@ def test_job_service_client_client_options(client_class, transport_class, transp client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -208,26 +248,54 @@ def test_job_service_client_client_options(client_class, transport_class, transp client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient)) -@mock.patch.object(JobServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_job_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
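# Illustrative sketch: the autoswitch rule described in the comments above
# reduces to a small pure function. This assumes only the two env vars and
# certificate availability drive the choice (names are hypothetical):
def sketch_effective_endpoint(use_mtls_env, use_client_cert_env, have_cert,
                              default_endpoint, mtls_endpoint):
    # A client certificate is only honored when the env var opts in.
    cert_used = use_client_cert_env == "true" and have_cert
    if use_mtls_env == "always":
        return mtls_endpoint
    if use_mtls_env == "auto" and cert_used:
        return mtls_endpoint
    return default_endpoint  # "never", or "auto" without a usable cert

assert sketch_effective_endpoint("auto", "false", True, "d", "m") == "d"
assert sketch_effective_endpoint("auto", "true", True, "d", "m") == "m"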
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -250,11 +318,21 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -264,7 +342,9 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -279,10 +359,17 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -297,16 +384,23 @@ def test_job_service_client_mtls_env_auto(client_class, transport_class, transpo ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_job_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -319,16 +413,24 @@ def test_job_service_client_client_options_scopes(client_class, transport_class, client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - (JobServiceAsyncClient, transports.JobServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_job_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -343,11 +445,11 @@ def test_job_service_client_client_options_credentials_file(client_class, transp def test_job_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = JobServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -359,10 +461,11 @@ def test_job_service_client_client_options_from_dict(): ) -def test_create_custom_job(transport: str = 'grpc', request_type=job_service.CreateCustomJobRequest): +def test_create_custom_job( + transport: str = "grpc", request_type=job_service.CreateCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -371,16 +474,13 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_custom_job(request) @@ -395,9 +495,9 @@ def test_create_custom_job(transport: str = 'grpc', request_type=job_service.Cre assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -407,10 +507,11 @@ def test_create_custom_job_from_dict(): @pytest.mark.asyncio -async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateCustomJobRequest): +async def test_create_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -419,14 +520,16 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. 
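# Illustrative sketch: grpc_helpers_async.FakeUnaryUnaryCall, used in the
# async tests below, is in essence an awaitable that resolves to a canned
# response, which is why the tests can simply `await client.create_custom_job(request)`.
# A minimal stand-in under that assumption (hypothetical class, not the real helper):
class SketchFakeUnaryUnaryCall:
    def __init__(self, response):
        self._response = response

    def __await__(self):
        yield from ()  # makes this a generator-based awaitable
        return self._response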
with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_custom_job(request) @@ -439,9 +542,9 @@ async def test_create_custom_job_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -452,19 +555,17 @@ async def test_create_custom_job_async_from_dict(): def test_create_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: call.return_value = gca_custom_job.CustomJob() client.create_custom_job(request) @@ -476,28 +577,25 @@ def test_create_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateCustomJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) await client.create_custom_job(request) @@ -508,29 +606,24 @@ async def test_create_custom_job_field_headers_async(): # Establish that the field header was sent. 
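# Illustrative sketch: the ("x-goog-request-params", ...) pair asserted on
# below is routing metadata built from request fields. A minimal rendering of
# that convention, assuming plain "name=value" pairs joined by "&" (the
# helper name is hypothetical):
def sketch_routing_metadata(fields):
    return ("x-goog-request-params", "&".join(f"{k}={v}" for k, v in fields))

assert sketch_routing_metadata([("parent", "parent/value")]) == (
    "x-goog-request-params",
    "parent=parent/value",
)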
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -538,45 +631,43 @@ def test_create_custom_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") def test_create_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_custom_job), - '__call__') as call: + type(client.transport.create_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.create_custom_job( - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -584,31 +675,30 @@ async def test_create_custom_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name='name_value') + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") @pytest.mark.asyncio async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_custom_job( job_service.CreateCustomJobRequest(), - parent='parent_value', - custom_job=gca_custom_job.CustomJob(name='name_value'), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), ) -def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCustomJobRequest): +def test_get_custom_job( + transport: str = "grpc", request_type=job_service.GetCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -616,17 +706,12 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_custom_job(request) @@ -641,9 +726,9 @@ def test_get_custom_job(transport: str = 'grpc', request_type=job_service.GetCus assert isinstance(response, custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -653,10 +738,11 @@ def test_get_custom_job_from_dict(): @pytest.mark.asyncio -async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetCustomJobRequest): +async def test_get_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -664,15 +750,15 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob( - name='name_value', - display_name='display_name_value', - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_custom_job(request) @@ -685,9 +771,9 @@ async def test_get_custom_job_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, custom_job.CustomJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED @@ -698,19 +784,15 @@ async def test_get_custom_job_async_from_dict(): def test_get_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: call.return_value = custom_job.CustomJob() client.get_custom_job(request) @@ -722,28 +804,23 @@ def test_get_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) await client.get_custom_job(request) @@ -754,99 +831,81 @@ async def test_get_custom_job_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_custom_job( - name='name_value', - ) + client.get_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_custom_job( - job_service.GetCustomJobRequest(), - name='name_value', + job_service.GetCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_custom_job), - '__call__') as call: + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = custom_job.CustomJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(custom_job.CustomJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_custom_job( - name='name_value', - ) + response = await client.get_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
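# Illustrative sketch: the ValueError exercised in the flattened-error tests
# enforces that a method receives either a prebuilt request object or
# flattened fields, never both. A minimal guard expressing that contract
# (hypothetical helper; the exact error text in the real client may differ):
def sketch_validate_flattened(request, **flattened):
    if request is not None and any(v is not None for v in flattened.values()):
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )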
with pytest.raises(ValueError): await client.get_custom_job( - job_service.GetCustomJobRequest(), - name='name_value', + job_service.GetCustomJobRequest(), name="name_value", ) -def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.ListCustomJobsRequest): +def test_list_custom_jobs( + transport: str = "grpc", request_type=job_service.ListCustomJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -854,13 +913,10 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_custom_jobs(request) @@ -875,7 +931,7 @@ def test_list_custom_jobs(transport: str = 'grpc', request_type=job_service.List assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_custom_jobs_from_dict(): @@ -883,10 +939,11 @@ def test_list_custom_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListCustomJobsRequest): +async def test_list_custom_jobs_async( + transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -894,13 +951,11 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_custom_jobs(request) @@ -913,7 +968,7 @@ async def test_list_custom_jobs_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListCustomJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -922,19 +977,15 @@ async def test_list_custom_jobs_async_from_dict(): def test_list_custom_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: call.return_value = job_service.ListCustomJobsResponse() client.list_custom_jobs(request) @@ -946,28 +997,23 @@ def test_list_custom_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListCustomJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) await client.list_custom_jobs(request) @@ -978,104 +1024,81 @@ async def test_list_custom_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_custom_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_custom_jobs( - parent='parent_value', - ) + client.list_custom_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_custom_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_custom_jobs( - job_service.ListCustomJobsRequest(), - parent='parent_value', + job_service.ListCustomJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListCustomJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListCustomJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_custom_jobs( - parent='parent_value', - ) + response = await client.list_custom_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), - parent='parent_value', + job_service.ListCustomJobsRequest(), parent="parent_value", ) def test_list_custom_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1084,32 +1107,21 @@ def test_list_custom_jobs_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_custom_jobs(request={}) @@ -1117,18 +1129,14 @@ def test_list_custom_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) - for i in results) + assert all(isinstance(i, custom_job.CustomJob) for i in results) + def test_list_custom_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__') as call: + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1137,40 +1145,30 @@ def test_list_custom_jobs_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
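# Illustrative sketch: the pager tests above and below rely on one contract -
# keep reissuing the RPC with the returned next_page_token until it comes
# back empty, yielding each page along the way. A minimal generator under
# that assumption (hypothetical name):
def sketch_iterate_pages(fetch_page):
    token = ""
    while True:
        page = fetch_page(token)
        yield page
        token = page.next_page_token
        if not token:  # "" terminates, matching zip(pages, ["abc", "def", "ghi", ""])
            break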
call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1179,46 +1177,35 @@ async def test_list_custom_jobs_async_pager(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) - for i in responses) + assert all(isinstance(i, custom_job.CustomJob) for i in responses) + @pytest.mark.asyncio async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_custom_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListCustomJobsResponse( @@ -1227,37 +1214,29 @@ async def test_list_custom_jobs_async_pages(): custom_job.CustomJob(), custom_job.CustomJob(), ], - next_page_token='abc', + next_page_token="abc", ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), job_service.ListCustomJobsResponse( - custom_jobs=[], - next_page_token='def', + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", ), job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - ], - next_page_token='ghi', - ), - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - ], + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], ), RuntimeError, ) pages = [] async for page_ in (await client.list_custom_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.DeleteCustomJobRequest): +def test_delete_custom_job( + transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1266,10 +1245,10 @@ def test_delete_custom_job(transport: str = 'grpc', request_type=job_service.Del # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_custom_job(request) @@ -1288,10 +1267,11 @@ def test_delete_custom_job_from_dict(): @pytest.mark.asyncio -async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteCustomJobRequest): +async def test_delete_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1300,11 +1280,11 @@ async def test_delete_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_custom_job(request) @@ -1325,20 +1305,18 @@ async def test_delete_custom_job_async_from_dict(): def test_delete_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_custom_job(request) @@ -1349,28 +1327,25 @@ def test_delete_custom_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_custom_job(request) @@ -1381,101 +1356,85 @@ async def test_delete_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_custom_job( - name='name_value', - ) + client.delete_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_custom_job( - job_service.DeleteCustomJobRequest(), - name='name_value', + job_service.DeleteCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_custom_job), - '__call__') as call: + type(client.transport.delete_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_custom_job( - name='name_value', - ) + response = await client.delete_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), - name='name_value', + job_service.DeleteCustomJobRequest(), name="name_value", ) -def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.CancelCustomJobRequest): +def test_cancel_custom_job( + transport: str = "grpc", request_type=job_service.CancelCustomJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1484,8 +1443,8 @@ def test_cancel_custom_job(transport: str = 'grpc', request_type=job_service.Can # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1506,10 +1465,11 @@ def test_cancel_custom_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelCustomJobRequest): +async def test_cancel_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1518,8 +1478,8 @@ async def test_cancel_custom_job_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1541,19 +1501,17 @@ async def test_cancel_custom_job_async_from_dict(): def test_cancel_custom_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: call.return_value = None client.cancel_custom_job(request) @@ -1565,27 +1523,22 @@ def test_cancel_custom_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelCustomJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_custom_job(request) @@ -1597,99 +1550,83 @@ async def test_cancel_custom_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_custom_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_custom_job( - name='name_value', - ) + client.cancel_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_custom_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_custom_job( - job_service.CancelCustomJobRequest(), - name='name_value', + job_service.CancelCustomJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_custom_job), - '__call__') as call: + type(client.transport.cancel_custom_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
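The field-header hunks condense a routing check: whenever a request carries a URI-bound field such as name or parent, the client mirrors it into the x-goog-request-params gRPC metadata so the backend can route the call. A condensed sketch of the same assertion, under the same import assumptions as above:

from unittest import mock

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.cloud.aiplatform_v1beta1.types import job_service

client = JobServiceClient(credentials=credentials.AnonymousCredentials())
request = job_service.CancelCustomJobRequest(name="name/value")
with mock.patch.object(type(client.transport.cancel_custom_job), "__call__") as call:
    call.return_value = None
    client.cancel_custom_job(request=request)
    # The resource name travels out-of-band in the call metadata.
    _, _, kw = call.mock_calls[0]
    assert ("x-goog-request-params", "name=name/value") in kw["metadata"]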
- response = await client.cancel_custom_job( - name='name_value', - ) + response = await client.cancel_custom_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), - name='name_value', + job_service.CancelCustomJobRequest(), name="name_value", ) -def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_service.CreateDataLabelingJobRequest): +def test_create_data_labeling_job( + transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1698,28 +1635,19 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob( - name='name_value', - - display_name='display_name_value', - - datasets=['datasets_value'], - + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], labeler_count=1375, - - instruction_uri='instruction_uri_value', - - inputs_schema_uri='inputs_schema_uri_value', - + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - - specialist_pools=['specialist_pools_value'], - + specialist_pools=["specialist_pools_value"], ) response = client.create_data_labeling_job(request) @@ -1734,23 +1662,23 @@ def test_create_data_labeling_job(transport: str = 'grpc', request_type=job_serv assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_create_data_labeling_job_from_dict(): @@ -1758,10 +1686,12 @@ def test_create_data_labeling_job_from_dict(): 
@pytest.mark.asyncio -async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateDataLabelingJobRequest): +async def test_create_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1770,20 +1700,22 @@ async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob( - name='name_value', - display_name='display_name_value', - datasets=['datasets_value'], - labeler_count=1375, - instruction_uri='instruction_uri_value', - inputs_schema_uri='inputs_schema_uri_value', - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=['specialist_pools_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) response = await client.create_data_labeling_job(request) @@ -1796,23 +1728,23 @@ async def test_create_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] @pytest.mark.asyncio @@ -1821,19 +1753,17 @@ async def test_create_data_labeling_job_async_from_dict(): def test_create_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. 
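The hunks above collapse a fully populated DataLabelingJob fixture into black's wrapping; the long run of asserts then verifies that every field round-trips through the client. Because responses are proto-plus messages, each field is a typed attribute and unset scalars read as proto3 defaults rather than None. A standalone sketch (field values arbitrary, imports as in the test module):

from google.cloud.aiplatform_v1beta1.types import data_labeling_job as gca_data_labeling_job
from google.cloud.aiplatform_v1beta1.types import job_state

job = gca_data_labeling_job.DataLabelingJob(
    name="name_value",
    display_name="display_name_value",
    datasets=["datasets_value"],
    labeler_count=1375,
    state=job_state.JobState.JOB_STATE_QUEUED,
)
assert job.name == "name_value"
assert job.datasets == ["datasets_value"]
assert job.labeler_count == 1375
# Unset scalar fields fall back to proto3 defaults, not None.
assert job.labeling_progress == 0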
with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: call.return_value = gca_data_labeling_job.DataLabelingJob() client.create_data_labeling_job(request) @@ -1845,28 +1775,25 @@ def test_create_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateDataLabelingJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) await client.create_data_labeling_job(request) @@ -1877,29 +1804,24 @@ async def test_create_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.create_data_labeling_job( - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -1907,45 +1829,45 @@ def test_create_data_labeling_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_data_labeling_job), - '__call__') as call: + type(client.transport.create_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_data_labeling_job( - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -1953,31 +1875,32 @@ async def test_create_data_labeling_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(name='name_value') + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
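The *_flattened and *_flattened_error pairs encode one client contract: each method accepts either a full request object or flattened keyword fields, and mixing the two raises ValueError before any RPC is attempted. A condensed sketch of the failure path, same assumptions as above:

import pytest

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.cloud.aiplatform_v1beta1.types import job_service

client = JobServiceClient(credentials=credentials.AnonymousCredentials())
# A request object plus a flattened field is rejected client-side.
with pytest.raises(ValueError):
    client.create_data_labeling_job(
        job_service.CreateDataLabelingJobRequest(), parent="parent_value",
    )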
with pytest.raises(ValueError): await client.create_data_labeling_job( job_service.CreateDataLabelingJobRequest(), - parent='parent_value', - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name='name_value'), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), ) -def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service.GetDataLabelingJobRequest): +def test_get_data_labeling_job( + transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1986,28 +1909,19 @@ def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob( - name='name_value', - - display_name='display_name_value', - - datasets=['datasets_value'], - + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], labeler_count=1375, - - instruction_uri='instruction_uri_value', - - inputs_schema_uri='inputs_schema_uri_value', - + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - - specialist_pools=['specialist_pools_value'], - + specialist_pools=["specialist_pools_value"], ) response = client.get_data_labeling_job(request) @@ -2022,23 +1936,23 @@ def test_get_data_labeling_job(transport: str = 'grpc', request_type=job_service assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] def test_get_data_labeling_job_from_dict(): @@ -2046,10 +1960,11 @@ def test_get_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetDataLabelingJobRequest): +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2058,20 +1973,22 @@ async def test_get_data_labeling_job_async(transport: str 
= 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob( - name='name_value', - display_name='display_name_value', - datasets=['datasets_value'], - labeler_count=1375, - instruction_uri='instruction_uri_value', - inputs_schema_uri='inputs_schema_uri_value', - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=['specialist_pools_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) response = await client.get_data_labeling_job(request) @@ -2084,23 +2001,23 @@ async def test_get_data_labeling_job_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.datasets == ['datasets_value'] + assert response.datasets == ["datasets_value"] assert response.labeler_count == 1375 - assert response.instruction_uri == 'instruction_uri_value' + assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == 'inputs_schema_uri_value' + assert response.inputs_schema_uri == "inputs_schema_uri_value" assert response.state == job_state.JobState.JOB_STATE_QUEUED assert response.labeling_progress == 1810 - assert response.specialist_pools == ['specialist_pools_value'] + assert response.specialist_pools == ["specialist_pools_value"] @pytest.mark.asyncio @@ -2109,19 +2026,17 @@ async def test_get_data_labeling_job_async_from_dict(): def test_get_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: call.return_value = data_labeling_job.DataLabelingJob() client.get_data_labeling_job(request) @@ -2133,28 +2048,25 @@ def test_get_data_labeling_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) await client.get_data_labeling_job(request) @@ -2165,99 +2077,85 @@ async def test_get_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = data_labeling_job.DataLabelingJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_data_labeling_job( - name='name_value', - ) + client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_data_labeling_job), - '__call__') as call: + type(client.transport.get_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = data_labeling_job.DataLabelingJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(data_labeling_job.DataLabelingJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_data_labeling_job( - name='name_value', - ) + response = await client.get_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), - name='name_value', + job_service.GetDataLabelingJobRequest(), name="name_value", ) -def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_service.ListDataLabelingJobsRequest): +def test_list_data_labeling_jobs( + transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2266,12 +2164,11 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_data_labeling_jobs(request) @@ -2286,7 +2183,7 @@ def test_list_data_labeling_jobs(transport: str = 'grpc', request_type=job_servi assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_data_labeling_jobs_from_dict(): @@ -2294,10 +2191,12 @@ def test_list_data_labeling_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListDataLabelingJobsRequest): +async def test_list_data_labeling_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListDataLabelingJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2306,12 +2205,14 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', re # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_data_labeling_jobs(request) @@ -2324,7 +2225,7 @@ async def test_list_data_labeling_jobs_async(transport: str = 'grpc_asyncio', re # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2333,19 +2234,17 @@ async def test_list_data_labeling_jobs_async_from_dict(): def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: call.return_value = job_service.ListDataLabelingJobsResponse() client.list_data_labeling_jobs(request) @@ -2357,28 +2256,25 @@ def test_list_data_labeling_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_data_labeling_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListDataLabelingJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) await client.list_data_labeling_jobs(request) @@ -2389,104 +2285,87 @@ async def test_list_data_labeling_jobs_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_data_labeling_jobs( - parent='parent_value', - ) + client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListDataLabelingJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListDataLabelingJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_data_labeling_jobs( - parent='parent_value', - ) + response = await client.list_data_labeling_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_data_labeling_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), - parent='parent_value', + job_service.ListDataLabelingJobsRequest(), parent="parent_value", ) def test_list_data_labeling_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2495,17 +2374,14 @@ def test_list_data_labeling_jobs_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2518,9 +2394,7 @@ def test_list_data_labeling_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_data_labeling_jobs(request={}) @@ -2528,18 +2402,16 @@ def test_list_data_labeling_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in results) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) + def test_list_data_labeling_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__') as call: + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2548,17 +2420,14 @@ def test_list_data_labeling_jobs_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2569,19 +2438,20 @@ def test_list_data_labeling_jobs_pages(): RuntimeError, ) pages = list(client.list_data_labeling_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2590,17 +2460,14 @@ async def test_list_data_labeling_jobs_async_pager(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2611,25 +2478,25 @@ async def test_list_data_labeling_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_data_labeling_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) - for i in responses) + assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses) + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_data_labeling_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_data_labeling_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
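Beyond item iteration, the pages variants above check page-level iteration: .pages yields one response per backend call, stops when next_page_token is empty, and raw_page exposes the underlying proto for each page. A condensed sketch under the same assumptions:

from unittest import mock

from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.cloud.aiplatform_v1beta1.types import job_service

client = JobServiceClient(credentials=credentials.AnonymousCredentials())
with mock.patch.object(
    type(client.transport.list_data_labeling_jobs), "__call__"
) as call:
    call.side_effect = (
        job_service.ListDataLabelingJobsResponse(next_page_token="abc"),
        job_service.ListDataLabelingJobsResponse(next_page_token=""),
    )
    pages = list(client.list_data_labeling_jobs(request={}).pages)
    assert [p.raw_page.next_page_token for p in pages] == ["abc", ""]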
call.side_effect = ( job_service.ListDataLabelingJobsResponse( @@ -2638,17 +2505,14 @@ async def test_list_data_labeling_jobs_async_pages(): data_labeling_job.DataLabelingJob(), data_labeling_job.DataLabelingJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], - next_page_token='def', + data_labeling_jobs=[], next_page_token="def", ), job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - ], - next_page_token='ghi', + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", ), job_service.ListDataLabelingJobsResponse( data_labeling_jobs=[ @@ -2661,14 +2525,15 @@ async def test_list_data_labeling_jobs_async_pages(): pages = [] async for page_ in (await client.list_data_labeling_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_service.DeleteDataLabelingJobRequest): +def test_delete_data_labeling_job( + transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2677,10 +2542,10 @@ def test_delete_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_data_labeling_job(request) @@ -2699,10 +2564,12 @@ def test_delete_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteDataLabelingJobRequest): +async def test_delete_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2711,11 +2578,11 @@ async def test_delete_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_data_labeling_job(request) @@ -2736,20 +2603,18 @@ async def test_delete_data_labeling_job_async_from_dict(): def test_delete_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_data_labeling_job(request) @@ -2760,28 +2625,25 @@ def test_delete_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_data_labeling_job(request) @@ -2792,101 +2654,85 @@ async def test_delete_data_labeling_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.delete_data_labeling_job( - name='name_value', - ) + client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_data_labeling_job), - '__call__') as call: + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_data_labeling_job( - name='name_value', - ) + response = await client.delete_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), - name='name_value', + job_service.DeleteDataLabelingJobRequest(), name="name_value", ) -def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_service.CancelDataLabelingJobRequest): +def test_cancel_data_labeling_job( + transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2895,8 +2741,8 @@ def test_cancel_data_labeling_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. 
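# The stubbing idiom here patches __call__ on the *type* of the transport's
# RPC attribute: each RPC is a callable multicallable object, so replacing
# type(rpc).__call__ intercepts every invocation of that instance. A small
# sketch of the same idea with a stand-in callable class (FakeRpc is
# illustrative, not part of the library):
#
#     from unittest import mock
#
#     class FakeRpc:  # stand-in for a gRPC multicallable
#         def __call__(self):
#             raise NotImplementedError  # never reached once patched
#
#     rpc = FakeRpc()
#     with mock.patch.object(type(rpc), "__call__") as call:
#         call.return_value = "response"
#         assert rpc() == "response"
#         call.assert_called_once()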
with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2917,10 +2763,12 @@ def test_cancel_data_labeling_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelDataLabelingJobRequest): +async def test_cancel_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelDataLabelingJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2929,8 +2777,8 @@ async def test_cancel_data_labeling_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -2952,19 +2800,17 @@ async def test_cancel_data_labeling_job_async_from_dict(): def test_cancel_data_labeling_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = None client.cancel_data_labeling_job(request) @@ -2976,27 +2822,22 @@ def test_cancel_data_labeling_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelDataLabelingJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_data_labeling_job(request) @@ -3008,99 +2849,84 @@ async def test_cancel_data_labeling_job_field_headers_async(): # Establish that the field header was sent. 
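# The metadata tuple asserted below is what google-api-core's routing-header
# helper builds from URI path fields; note that "/" is deliberately left
# unencoded. A sketch, assuming gapic_v1.routing_header behaves as in the
# google-api-core release pinned for these tests:
#
#     from google.api_core.gapic_v1 import routing_header
#
#     md = routing_header.to_grpc_metadata([("name", "name/value")])
#     assert md == ("x-goog-request-params", "name=name/value")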
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_data_labeling_job( - name='name_value', - ) + client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_data_labeling_job), - '__call__') as call: + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_data_labeling_job( - name='name_value', - ) + response = await client.cancel_data_labeling_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), - name='name_value', + job_service.CancelDataLabelingJobRequest(), name="name_value", ) -def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CreateHyperparameterTuningJobRequest): +def test_create_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3109,22 +2935,16 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_hyperparameter_tuning_job(request) @@ -3139,9 +2959,9 @@ def test_create_hyperparameter_tuning_job(transport: str = 'grpc', request_type= assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3157,10 +2977,12 @@ def test_create_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateHyperparameterTuningJobRequest): +async def test_create_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3169,17 +2991,19 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_hyperparameter_tuning_job(request) @@ -3192,9 +3016,9 @@ async def test_create_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Establish that the response is the type that we expect. assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3211,19 +3035,17 @@ async def test_create_hyperparameter_tuning_job_async_from_dict(): def test_create_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() client.create_hyperparameter_tuning_job(request) @@ -3235,28 +3057,25 @@ def test_create_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateHyperparameterTuningJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. 
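# Two aliases of the same types module appear in these tests:
# hyperparameter_tuning_job (read paths) and gca_hyperparameter_tuning_job
# (create/update paths). In the generated code they typically resolve to the
# identical proto module, imported twice to mirror the service modules; a
# sketch of that import shape (treat the equivalence as an assumption about
# this generator output, not a guarantee of every release):
#
#     from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job
#     from google.cloud.aiplatform_v1beta1.types import (
#         hyperparameter_tuning_job as gca_hyperparameter_tuning_job,
#     )
#
#     assert (gca_hyperparameter_tuning_job.HyperparameterTuningJob
#             is hyperparameter_tuning_job.HyperparameterTuningJob)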
with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.create_hyperparameter_tuning_job(request) @@ -3267,29 +3086,26 @@ async def test_create_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3297,45 +3113,51 @@ def test_create_hyperparameter_tuning_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.create_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_hyperparameter_tuning_job( - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -3343,31 +3165,36 @@ async def test_create_hyperparameter_tuning_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value') + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.create_hyperparameter_tuning_job( job_service.CreateHyperparameterTuningJobRequest(), - parent='parent_value', - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(name='name_value'), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), ) -def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.GetHyperparameterTuningJobRequest): +def test_get_hyperparameter_tuning_job( + transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3376,22 +3203,16 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_hyperparameter_tuning_job(request) @@ -3406,9 +3227,9 @@ def test_get_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3424,10 +3245,12 @@ def test_get_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetHyperparameterTuningJobRequest): +async def test_get_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3436,17 +3259,19 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob( - name='name_value', - display_name='display_name_value', - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_hyperparameter_tuning_job(request) @@ -3459,9 +3284,9 @@ async def test_get_hyperparameter_tuning_job_async(transport: str = 'grpc_asynci # Establish that the response is the type that we expect. assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.max_trial_count == 1609 @@ -3478,19 +3303,17 @@ async def test_get_hyperparameter_tuning_job_async_from_dict(): def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() client.get_hyperparameter_tuning_job(request) @@ -3502,28 +3325,25 @@ def test_get_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) await client.get_hyperparameter_tuning_job(request) @@ -3534,99 +3354,86 @@ async def test_get_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_hyperparameter_tuning_job( - name='name_value', - ) + client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(hyperparameter_tuning_job.HyperparameterTuningJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.get_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), - name='name_value', + job_service.GetHyperparameterTuningJobRequest(), name="name_value", ) -def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=job_service.ListHyperparameterTuningJobsRequest): +def test_list_hyperparameter_tuning_jobs( + transport: str = "grpc", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3635,12 +3442,11 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. 
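# The assertion a few lines down reads next_page_token off the *pager*, not
# the raw response: GAPIC pagers delegate unknown attributes to the wrapped
# response object. A compact sketch of that delegation (Raw and Pager are
# stand-in classes, not the library's implementation):
#
#     class Raw:  # stands in for a List*Response message
#         next_page_token = "next_page_token_value"
#
#     class Pager:
#         def __init__(self, raw):
#             self._response = raw
#         def __getattr__(self, name):  # fall through to the response
#             return getattr(self._response, name)
#
#     assert Pager(Raw()).next_page_token == "next_page_token_value"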
call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_hyperparameter_tuning_jobs(request) @@ -3655,7 +3461,7 @@ def test_list_hyperparameter_tuning_jobs(transport: str = 'grpc', request_type=j assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_hyperparameter_tuning_jobs_from_dict(): @@ -3663,10 +3469,12 @@ def test_list_hyperparameter_tuning_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListHyperparameterTuningJobsRequest): +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -3675,12 +3483,14 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_hyperparameter_tuning_jobs(request) @@ -3693,7 +3503,7 @@ async def test_list_hyperparameter_tuning_jobs_async(transport: str = 'grpc_asyn # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -3702,19 +3512,17 @@ async def test_list_hyperparameter_tuning_jobs_async_from_dict(): def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: call.return_value = job_service.ListHyperparameterTuningJobsResponse() client.list_hyperparameter_tuning_jobs(request) @@ -3726,28 +3534,25 @@ def test_list_hyperparameter_tuning_jobs_field_headers(): # Establish that the field header was sent. 
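# mock_calls entries unpack into (name, args, kwargs), which is how these
# tests fish the request metadata back out of the stub. A tiny standalone
# sketch:
#
#     from unittest import mock
#
#     m = mock.Mock()
#     m("req", metadata=[("k", "v")])
#     _, args, kw = m.mock_calls[0]  # (call name, positional, keyword)
#     assert args == ("req",)
#     assert ("k", "v") in kw["metadata"]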
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListHyperparameterTuningJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) await client.list_hyperparameter_tuning_jobs(request) @@ -3758,104 +3563,87 @@ async def test_list_hyperparameter_tuning_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListHyperparameterTuningJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListHyperparameterTuningJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_hyperparameter_tuning_jobs( - parent='parent_value', - ) + response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), - parent='parent_value', + job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", ) def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. 
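# side_effect hands the stub one response per invocation, and the pager keeps
# fetching until next_page_token is empty, which is why four pages holding
# 3 + 0 + 1 + 2 jobs yield six results below. The core loop, sketched without
# any client machinery:
#
#     pages = iter([
#         {"items": [1, 2, 3], "next_page_token": "abc"},
#         {"items": [],        "next_page_token": "def"},
#         {"items": [4],       "next_page_token": "ghi"},
#         {"items": [5, 6],    "next_page_token": ""},
#     ])
#     results = []
#     while True:
#         page = next(pages)  # stands in for one stubbed RPC call
#         results.extend(page["items"])
#         if not page["next_page_token"]:
#             break
#     assert results == [1, 2, 3, 4, 5, 6]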
call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3864,17 +3652,16 @@ def test_list_hyperparameter_tuning_jobs_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3887,9 +3674,7 @@ def test_list_hyperparameter_tuning_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_hyperparameter_tuning_jobs(request={}) @@ -3897,18 +3682,19 @@ def test_list_hyperparameter_tuning_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in results + ) + def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__') as call: + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3917,17 +3703,16 @@ def test_list_hyperparameter_tuning_jobs_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3938,19 +3723,20 @@ def test_list_hyperparameter_tuning_jobs_pages(): RuntimeError, ) pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. 
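# new_callable=mock.AsyncMock is what makes the patched RPC awaitable in the
# async pager tests; awaiting an AsyncMock consumes its side_effect sequence
# one item per call. A self-contained sketch (AsyncMock requires Python 3.8+):
#
#     import asyncio
#     from unittest import mock
#
#     async def demo():
#         call = mock.AsyncMock(side_effect=["page-1", "page-2"])
#         assert await call() == "page-1"
#         assert await call() == "page-2"
#
#     asyncio.run(demo())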
with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -3959,17 +3745,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -3980,25 +3765,28 @@ async def test_list_hyperparameter_tuning_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_hyperparameter_tuning_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in responses) + assert all( + isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) + for i in responses + ) + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_hyperparameter_tuning_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_hyperparameter_tuning_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListHyperparameterTuningJobsResponse( @@ -4007,17 +3795,16 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): hyperparameter_tuning_job.HyperparameterTuningJob(), hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], - next_page_token='def', + hyperparameter_tuning_jobs=[], next_page_token="def", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ hyperparameter_tuning_job.HyperparameterTuningJob(), ], - next_page_token='ghi', + next_page_token="ghi", ), job_service.ListHyperparameterTuningJobsResponse( hyperparameter_tuning_jobs=[ @@ -4028,16 +3815,20 @@ async def test_list_hyperparameter_tuning_jobs_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_hyperparameter_tuning_jobs(request={})).pages: + async for page_ in ( + await client.list_hyperparameter_tuning_jobs(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.DeleteHyperparameterTuningJobRequest): +def test_delete_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4046,10 +3837,10 @@ def test_delete_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_hyperparameter_tuning_job(request) @@ -4068,10 +3859,12 @@ def test_delete_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteHyperparameterTuningJobRequest): +async def test_delete_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4080,11 +3873,11 @@ async def test_delete_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_hyperparameter_tuning_job(request) @@ -4105,20 +3898,18 @@ async def test_delete_hyperparameter_tuning_job_async_from_dict(): def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_hyperparameter_tuning_job(request) @@ -4129,28 +3920,25 @@ def test_delete_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_hyperparameter_tuning_job(request) @@ -4161,101 +3949,86 @@ async def test_delete_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job( - name='name_value', - ) + client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.delete_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), - name='name_value', + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", ) -def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type=job_service.CancelHyperparameterTuningJobRequest): +def test_cancel_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4264,8 +4037,8 @@ def test_cancel_hyperparameter_tuning_job(transport: str = 'grpc', request_type= # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -4286,10 +4059,12 @@ def test_cancel_hyperparameter_tuning_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelHyperparameterTuningJobRequest): +async def test_cancel_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4298,8 +4073,8 @@ async def test_cancel_hyperparameter_tuning_job_async(transport: str = 'grpc_asy # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -4321,19 +4096,17 @@ async def test_cancel_hyperparameter_tuning_job_async_from_dict(): def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = None client.cancel_hyperparameter_tuning_job(request) @@ -4345,27 +4118,22 @@ def test_cancel_hyperparameter_tuning_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelHyperparameterTuningJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_hyperparameter_tuning_job(request) @@ -4377,99 +4145,83 @@ async def test_cancel_hyperparameter_tuning_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_hyperparameter_tuning_job), - '__call__') as call: + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
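# Throughout this file the tests patch "__call__" on type(...) rather than
# on the multicallable instance, because Python resolves special methods on
# the type. Once patched, invoking the stub dispatches straight to the
# mock, which records the request without touching the network. A
# self-contained sketch of the same trick (FakeMulticallable stands in for
# the gRPC multicallable; this is an illustration, not the library's code):
from unittest import mock

class FakeMulticallable:
    def __call__(self, request):
        raise AssertionError("unit tests must never hit the network")

stub = FakeMulticallable()
with mock.patch.object(type(stub), "__call__") as call:
    call.return_value = None
    stub(request={})  # intercepted; nothing is actually sent
assert len(call.mock_calls) == 1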
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job( - name='name_value', - ) + response = await client.cancel_hyperparameter_tuning_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), - name='name_value', + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", ) -def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CreateBatchPredictionJobRequest): +def test_create_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4478,20 +4230,15 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.create_batch_prediction_job(request) @@ -4506,11 +4253,11 @@ def test_create_batch_prediction_job(transport: str = 'grpc', request_type=job_s assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4522,10 +4269,12 @@ def test_create_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CreateBatchPredictionJobRequest): +async def test_create_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4534,16 +4283,18 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.create_batch_prediction_job(request) @@ -4556,11 +4307,11 @@ async def test_create_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Establish that the response is the type that we expect. assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4573,19 +4324,17 @@ async def test_create_batch_prediction_job_async_from_dict(): def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: call.return_value = gca_batch_prediction_job.BatchPredictionJob() client.create_batch_prediction_job(request) @@ -4597,28 +4346,25 @@ def test_create_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CreateBatchPredictionJobRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) await client.create_batch_prediction_job(request) @@ -4629,29 +4375,26 @@ async def test_create_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -4659,45 +4402,51 @@ def test_create_batch_prediction_job_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_batch_prediction_job), - '__call__') as call: + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gca_batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.create_batch_prediction_job( - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) # Establish that the underlying call was made with the expected @@ -4705,31 +4454,36 @@ async def test_create_batch_prediction_job_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob(name='name_value') + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) @pytest.mark.asyncio async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.create_batch_prediction_job( job_service.CreateBatchPredictionJobRequest(), - parent='parent_value', - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob(name='name_value'), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), ) -def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_service.GetBatchPredictionJobRequest): +def test_get_batch_prediction_job( + transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4738,20 +4492,15 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob( - name='name_value', - - display_name='display_name_value', - - model='model_value', - + name="name_value", + display_name="display_name_value", + model="model_value", generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) response = client.get_batch_prediction_job(request) @@ -4766,11 +4515,11 @@ def test_get_batch_prediction_job(transport: str = 'grpc', request_type=job_serv assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4782,10 +4531,12 @@ def test_get_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.GetBatchPredictionJobRequest): +async def test_get_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4794,16 +4545,18 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob( - name='name_value', - display_name='display_name_value', - model='model_value', - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) response = await client.get_batch_prediction_job(request) @@ -4816,11 +4569,11 @@ async def test_get_batch_prediction_job_async(transport: str = 'grpc_asyncio', r # Establish that the response is the type that we expect. assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.model == 'model_value' + assert response.model == "model_value" assert response.generate_explanation is True @@ -4833,19 +4586,17 @@ async def test_get_batch_prediction_job_async_from_dict(): def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: call.return_value = batch_prediction_job.BatchPredictionJob() client.get_batch_prediction_job(request) @@ -4857,28 +4608,25 @@ def test_get_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.GetBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) await client.get_batch_prediction_job(request) @@ -4889,99 +4637,85 @@ async def test_get_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. 
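# The field-header assertions here check that any request field baked into
# the HTTP/1.1 URI (the `name` set above) is mirrored into gRPC metadata as
# x-goog-request-params, which is what lets the backend route the call.
# google.api_core builds that pair with the same helper the pager tests in
# this file use; a minimal sketch:
from google.api_core import gapic_v1

metadata = gapic_v1.routing_header.to_grpc_metadata((("name", "name/value"),))
# Exactly the pair the assertions below look for in kw["metadata"]:
assert metadata == ("x-goog-request-params", "name=name/value")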
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_batch_prediction_job( - name='name_value', - ) + client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_batch_prediction_job), - '__call__') as call: + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = batch_prediction_job.BatchPredictionJob() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(batch_prediction_job.BatchPredictionJob()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_batch_prediction_job( - name='name_value', - ) + response = await client.get_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), - name='name_value', + job_service.GetBatchPredictionJobRequest(), name="name_value", ) -def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_service.ListBatchPredictionJobsRequest): +def test_list_batch_prediction_jobs( + transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -4990,12 +4724,11 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_batch_prediction_jobs(request) @@ -5010,7 +4743,7 @@ def test_list_batch_prediction_jobs(transport: str = 'grpc', request_type=job_se assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_batch_prediction_jobs_from_dict(): @@ -5018,10 +4751,12 @@ def test_list_batch_prediction_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', request_type=job_service.ListBatchPredictionJobsRequest): +async def test_list_batch_prediction_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListBatchPredictionJobsRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5030,12 +4765,14 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_batch_prediction_jobs(request) @@ -5048,7 +4785,7 @@ async def test_list_batch_prediction_jobs_async(transport: str = 'grpc_asyncio', # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -5057,19 +4794,17 @@ async def test_list_batch_prediction_jobs_async_from_dict(): def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: call.return_value = job_service.ListBatchPredictionJobsResponse() client.list_batch_prediction_jobs(request) @@ -5081,28 +4816,25 @@ def test_list_batch_prediction_jobs_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.ListBatchPredictionJobsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) await client.list_batch_prediction_jobs(request) @@ -5113,104 +4845,87 @@ async def test_list_batch_prediction_jobs_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
- client.list_batch_prediction_jobs( - parent='parent_value', - ) + client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = job_service.ListBatchPredictionJobsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(job_service.ListBatchPredictionJobsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs( - parent='parent_value', - ) + response = await client.list_batch_prediction_jobs(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), - parent='parent_value', + job_service.ListBatchPredictionJobsRequest(), parent="parent_value", ) def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5219,17 +4934,14 @@ def test_list_batch_prediction_jobs_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5242,9 +4954,7 @@ def test_list_batch_prediction_jobs_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_batch_prediction_jobs(request={}) @@ -5252,18 +4962,18 @@ def test_list_batch_prediction_jobs_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in results) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results + ) + def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__') as call: + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5272,17 +4982,14 @@ def test_list_batch_prediction_jobs_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5293,19 +5000,20 @@ def test_list_batch_prediction_jobs_pages(): RuntimeError, ) pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
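# The pager tests stack four fake pages (3 + 0 + 1 + 2 jobs) and expect six
# results: the pager re-issues the list request with each response's
# next_page_token until the token comes back empty, which is why the token
# walk ends with "". A minimal sketch of that loop, assuming a hypothetical
# fetch(page_token) -> (items, next_token) callable:
def iterate_pages(fetch):
    token = ""
    while True:
        items, token = fetch(token)
        yield from items
        if not token:
            break

fake_pages = {"": ([1, 2, 3], "abc"), "abc": ([], "def"),
              "def": ([4], "ghi"), "ghi": ([5, 6], "")}
assert list(iterate_pages(lambda t: fake_pages[t])) == [1, 2, 3, 4, 5, 6]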
call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5314,17 +5022,14 @@ async def test_list_batch_prediction_jobs_async_pager(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5335,25 +5040,27 @@ async def test_list_batch_prediction_jobs_async_pager(): RuntimeError, ) async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, batch_prediction_job.BatchPredictionJob) - for i in responses) + assert all( + isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses + ) + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_batch_prediction_jobs), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_batch_prediction_jobs), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( job_service.ListBatchPredictionJobsResponse( @@ -5362,17 +5069,14 @@ async def test_list_batch_prediction_jobs_async_pages(): batch_prediction_job.BatchPredictionJob(), batch_prediction_job.BatchPredictionJob(), ], - next_page_token='abc', + next_page_token="abc", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], - next_page_token='def', + batch_prediction_jobs=[], next_page_token="def", ), job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token='ghi', + batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], + next_page_token="ghi", ), job_service.ListBatchPredictionJobsResponse( batch_prediction_jobs=[ @@ -5385,14 +5089,15 @@ async def test_list_batch_prediction_jobs_async_pages(): pages = [] async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_service.DeleteBatchPredictionJobRequest): +def test_delete_batch_prediction_job( + transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5401,10 +5106,10 @@ def test_delete_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_batch_prediction_job(request) @@ -5423,10 +5128,12 @@ def test_delete_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.DeleteBatchPredictionJobRequest): +async def test_delete_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5435,11 +5142,11 @@ async def test_delete_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
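# In the async variants the client awaits the stubbed RPC, so the mock
# cannot hand back a bare value: it returns
# grpc_helpers_async.FakeUnaryUnaryCall, an awaitable wrapper around the
# canned response. A minimal homegrown equivalent, for illustration only:
import asyncio

class FakeCall:
    def __init__(self, response):
        self._response = response

    def __await__(self):
        async def _coro():
            return self._response  # awaiting yields the canned response
        return _coro().__await__()

async def demo():
    return await FakeCall("operations/spam")

assert asyncio.run(demo()) == "operations/spam"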
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_batch_prediction_job(request) @@ -5460,20 +5167,18 @@ async def test_delete_batch_prediction_job_async_from_dict(): def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_batch_prediction_job(request) @@ -5484,28 +5189,25 @@ def test_delete_batch_prediction_job_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.DeleteBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_batch_prediction_job(request) @@ -5516,101 +5218,85 @@ async def test_delete_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_batch_prediction_job( - name='name_value', - ) + client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_batch_prediction_job), - '__call__') as call: + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job( - name='name_value', - ) + response = await client.delete_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), - name='name_value', + job_service.DeleteBatchPredictionJobRequest(), name="name_value", ) -def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_service.CancelBatchPredictionJobRequest): +def test_cancel_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest +): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5619,8 +5305,8 @@ def test_cancel_batch_prediction_job(transport: str = 'grpc', request_type=job_s # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -5641,10 +5327,12 @@ def test_cancel_batch_prediction_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio', request_type=job_service.CancelBatchPredictionJobRequest): +async def test_cancel_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelBatchPredictionJobRequest, +): client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -5653,8 +5341,8 @@ async def test_cancel_batch_prediction_job_async(transport: str = 'grpc_asyncio' # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -5676,19 +5364,17 @@ async def test_cancel_batch_prediction_job_async_from_dict(): def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = None client.cancel_batch_prediction_job(request) @@ -5700,27 +5386,22 @@ def test_cancel_batch_prediction_job_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = job_service.CancelBatchPredictionJobRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_batch_prediction_job(request) @@ -5732,92 +5413,75 @@ async def test_cancel_batch_prediction_job_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_batch_prediction_job( - name='name_value', - ) + client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_batch_prediction_job), - '__call__') as call: + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job( - name='name_value', - ) + response = await client.cancel_batch_prediction_job(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), - name='name_value', + job_service.CancelBatchPredictionJobRequest(), name="name_value", ) @@ -5828,8 +5492,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -5848,8 +5511,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -5877,13 +5539,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.JobServiceGrpcTransport, - transports.JobServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -5891,13 +5553,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.JobServiceGrpcTransport, - ) + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.JobServiceGrpcTransport,) def test_job_service_base_transport_error(): @@ -5905,13 +5562,15 @@ def test_job_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.JobServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_job_service_base_transport(): # Instantiate the base transport. 
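# The base-transport test below instantiates the abstract JobServiceTransport
# and expects every RPC to raise NotImplementedError: the base class only
# fixes the interface, while the concrete gRPC / gRPC-asyncio transports
# supply the real callables. A minimal sketch of that contract, with
# hypothetical names rather than the generated classes:
import pytest

class BaseTransport:
    @property
    def create_custom_job(self):
        raise NotImplementedError()

transport = BaseTransport()
with pytest.raises(NotImplementedError):
    transport.create_custom_job(request=object())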
-    with mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__') as Transport:
+    with mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__"
+    ) as Transport:
         Transport.return_value = None
         transport = transports.JobServiceTransport(
             credentials=credentials.AnonymousCredentials(),
@@ -5920,27 +5579,27 @@
     # Every method on the transport should just blindly
     # raise NotImplementedError.
     methods = (
-        'create_custom_job',
-        'get_custom_job',
-        'list_custom_jobs',
-        'delete_custom_job',
-        'cancel_custom_job',
-        'create_data_labeling_job',
-        'get_data_labeling_job',
-        'list_data_labeling_jobs',
-        'delete_data_labeling_job',
-        'cancel_data_labeling_job',
-        'create_hyperparameter_tuning_job',
-        'get_hyperparameter_tuning_job',
-        'list_hyperparameter_tuning_jobs',
-        'delete_hyperparameter_tuning_job',
-        'cancel_hyperparameter_tuning_job',
-        'create_batch_prediction_job',
-        'get_batch_prediction_job',
-        'list_batch_prediction_jobs',
-        'delete_batch_prediction_job',
-        'cancel_batch_prediction_job',
-    )
+        "create_custom_job",
+        "get_custom_job",
+        "list_custom_jobs",
+        "delete_custom_job",
+        "cancel_custom_job",
+        "create_data_labeling_job",
+        "get_data_labeling_job",
+        "list_data_labeling_jobs",
+        "delete_data_labeling_job",
+        "cancel_data_labeling_job",
+        "create_hyperparameter_tuning_job",
+        "get_hyperparameter_tuning_job",
+        "list_hyperparameter_tuning_jobs",
+        "delete_hyperparameter_tuning_job",
+        "cancel_hyperparameter_tuning_job",
+        "create_batch_prediction_job",
+        "get_batch_prediction_job",
+        "list_batch_prediction_jobs",
+        "delete_batch_prediction_job",
+        "cancel_batch_prediction_job",
+    )
     for method in methods:
         with pytest.raises(NotImplementedError):
             getattr(transport, method)(request=object())

@@ -5953,23 +5612,28 @@ def test_job_service_base_transport():

 def test_job_service_base_transport_with_credentials_file():
     # Instantiate the base transport with a credentials file
-    with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport:
+    with mock.patch.object(
+        auth, "load_credentials_from_file"
+    ) as load_creds, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages"
+    ) as Transport:
         Transport.return_value = None
         load_creds.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.JobServiceTransport(
-            credentials_file="credentials.json",
-            quota_project_id="octopus",
+            credentials_file="credentials.json", quota_project_id="octopus",
         )
-        load_creds.assert_called_once_with("credentials.json", scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',
-        ),
+        load_creds.assert_called_once_with(
+            "credentials.json",
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
             quota_project_id="octopus",
         )


 def test_job_service_base_transport_with_adc():
     # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages') as Transport:
+    with mock.patch.object(auth, "default") as adc, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages"
+    ) as Transport:
         Transport.return_value = None
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.JobServiceTransport()
@@ -5978,11 +5642,11 @@ def test_job_service_auth_adc():
     # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         JobServiceClient()
-        adc.assert_called_once_with(scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',),
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
             quota_project_id=None,
         )

@@ -5990,37 +5654,43 @@ def test_job_service_transport_auth_adc():
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.JobServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
-        adc.assert_called_once_with(scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',),
+        transports.JobServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
             quota_project_id="octopus",
         )

+
 def test_job_service_host_no_port():
     client = JobServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
     )
-    assert client.transport._host == 'aiplatform.googleapis.com:443'
+    assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_job_service_host_with_port():
     client = JobServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
     )
-    assert client.transport._host == 'aiplatform.googleapis.com:8000'
+    assert client.transport._host == "aiplatform.googleapis.com:8000"


 def test_job_service_grpc_transport_channel():
-    channel = grpc.insecure_channel('http://localhost/')
+    channel = grpc.insecure_channel("http://localhost/")

     # Check that channel is used if provided.
     transport = transports.JobServiceGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
@@ -6028,24 +5698,28 @@ def test_job_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel('http://localhost/')
+    channel = aio.insecure_channel("http://localhost/")

     # Check that channel is used if provided.
     transport = transports.JobServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
     assert transport._ssl_channel_credentials == None


-@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport])
-def test_job_service_transport_channel_mtls_with_client_cert_source(
-    transport_class
-):
-    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
-        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
+)
+def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred

@@ -6054,7 +5728,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(
             cred = credentials.AnonymousCredentials()
             with pytest.warns(DeprecationWarning):
-                with mock.patch.object(auth, 'default') as adc:
+                with mock.patch.object(auth, "default") as adc:
                     adc.return_value = (cred, None)
                     transport = transport_class(
                         host="squid.clam.whelk",
@@ -6070,9 +5744,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(
                 "mtls.squid.clam.whelk:443",
                 credentials=cred,
                 credentials_file=None,
-                scopes=(
-                    'https://www.googleapis.com/auth/cloud-platform',
-                ),
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
             )
             assert transport._ssl_channel_credentials == mock_ssl_cred


-@pytest.mark.parametrize("transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport])
-def test_job_service_transport_channel_mtls_with_adc(
-    transport_class
-):
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
+)
+def test_job_service_transport_channel_mtls_with_adc(transport_class):
     mock_ssl_cred = mock.Mock()
     with mock.patch.multiple(
         "google.auth.transport.grpc.SslCredentials",
         __init__=mock.Mock(return_value=None),
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
-        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
             mock_cred = mock.Mock()
@@ -6107,9 +5782,7 @@ def test_job_service_transport_channel_mtls_with_adc(
                 "mtls.squid.clam.whelk:443",
                 credentials=mock_cred,
                 credentials_file=None,
-                scopes=(
-                    'https://www.googleapis.com/auth/cloud-platform',
-                ),
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
             )


 def test_job_service_grpc_lro_client():
     client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc',
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
     )
     transport = client.transport

     # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client

@@ -6135,36 +5804,36 @@ def test_job_service_grpc_lro_async_client():
     client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc_asyncio',
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
     )
     transport = client.transport

     # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsAsyncClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client

+
 def test_batch_prediction_job_path():
     project = "squid"
     location = "clam"
     batch_prediction_job = "whelk"

-    expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(project=project, location=location, batch_prediction_job=batch_prediction_job, )
-    actual = JobServiceClient.batch_prediction_job_path(project, location, batch_prediction_job)
+    expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(
+        project=project, location=location, batch_prediction_job=batch_prediction_job,
+    )
+    actual = JobServiceClient.batch_prediction_job_path(
+        project, location, batch_prediction_job
+    )
     assert expected == actual


 def test_parse_batch_prediction_job_path():
     expected = {
-    "project": "octopus",
-    "location": "oyster",
-    "batch_prediction_job": "nudibranch",
-
+        "project": "octopus",
+        "location": "oyster",
+        "batch_prediction_job": "nudibranch",
     }
     path = JobServiceClient.batch_prediction_job_path(**expected)

@@ -6172,22 +5841,24 @@ def test_parse_batch_prediction_job_path():
     actual = JobServiceClient.parse_batch_prediction_job_path(path)
     assert expected == actual

+
 def test_custom_job_path():
     project = "cuttlefish"
     location = "mussel"
     custom_job = "winkle"

-    expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(project=project, location=location, custom_job=custom_job, )
+    expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(
+        project=project, location=location, custom_job=custom_job,
+    )
     actual = JobServiceClient.custom_job_path(project, location, custom_job)
     assert expected == actual


 def test_parse_custom_job_path():
     expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "custom_job": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "custom_job": "abalone",
     }
     path = JobServiceClient.custom_job_path(**expected)

@@ -6195,22 +5866,26 @@ def test_parse_custom_job_path():
     actual = JobServiceClient.parse_custom_job_path(path)
     assert expected == actual

+
 def test_data_labeling_job_path():
     project = "squid"
     location = "clam"
     data_labeling_job = "whelk"

-    expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(project=project, location=location, data_labeling_job=data_labeling_job, )
-    actual = JobServiceClient.data_labeling_job_path(project, location, data_labeling_job)
+    expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(
+        project=project, location=location, data_labeling_job=data_labeling_job,
+    )
+    actual = JobServiceClient.data_labeling_job_path(
+        project, location, data_labeling_job
+    )
     assert expected == actual


 def test_parse_data_labeling_job_path():
     expected = {
-    "project": "octopus",
-    "location": "oyster",
-    "data_labeling_job": "nudibranch",
-
+        "project": "octopus",
+        "location": "oyster",
+        "data_labeling_job": "nudibranch",
     }
     path = JobServiceClient.data_labeling_job_path(**expected)

@@ -6218,22 +5893,24 @@ def test_parse_data_labeling_job_path():
     actual = JobServiceClient.parse_data_labeling_job_path(path)
     assert expected == actual

+
 def test_dataset_path():
     project = "cuttlefish"
     location = "mussel"
     dataset = "winkle"

-    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, )
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
     actual = JobServiceClient.dataset_path(project, location, dataset)
     assert expected == actual


 def test_parse_dataset_path():
     expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "dataset": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "dataset": "abalone",
     }
     path = JobServiceClient.dataset_path(**expected)

@@ -6241,22 +5918,28 @@ def test_parse_dataset_path():
     actual = JobServiceClient.parse_dataset_path(path)
     assert expected == actual

+
 def test_hyperparameter_tuning_job_path():
     project = "squid"
     location = "clam"
     hyperparameter_tuning_job = "whelk"

-    expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(project=project, location=location, hyperparameter_tuning_job=hyperparameter_tuning_job, )
-    actual = JobServiceClient.hyperparameter_tuning_job_path(project, location, hyperparameter_tuning_job)
+    expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(
+        project=project,
+        location=location,
+        hyperparameter_tuning_job=hyperparameter_tuning_job,
+    )
+    actual = JobServiceClient.hyperparameter_tuning_job_path(
+        project, location, hyperparameter_tuning_job
+    )
     assert expected == actual


 def test_parse_hyperparameter_tuning_job_path():
     expected = {
-    "project": "octopus",
-    "location": "oyster",
-    "hyperparameter_tuning_job": "nudibranch",
-
+        "project": "octopus",
+        "location": "oyster",
+        "hyperparameter_tuning_job": "nudibranch",
     }
     path = JobServiceClient.hyperparameter_tuning_job_path(**expected)

@@ -6264,22 +5947,24 @@ def test_parse_hyperparameter_tuning_job_path():
     actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path)
     assert expected == actual

+
 def test_model_path():
     project = "cuttlefish"
     location = "mussel"
     model = "winkle"

-    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
     actual = JobServiceClient.model_path(project, location, model)
     assert expected == actual


 def test_parse_model_path():
     expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "model": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
     }
     path = JobServiceClient.model_path(**expected)

@@ -6287,18 +5972,20 @@ def test_parse_model_path():
     actual = JobServiceClient.parse_model_path(path)
     assert expected == actual

+
 def test_common_billing_account_path():
     billing_account = "squid"

-    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
     actual = JobServiceClient.common_billing_account_path(billing_account)
     assert expected == actual


 def test_parse_common_billing_account_path():
     expected = {
-    "billing_account": "clam",
-
+        "billing_account": "clam",
     }
     path = JobServiceClient.common_billing_account_path(**expected)

@@ -6306,18 +5993,18 @@ def test_parse_common_billing_account_path():
     actual = JobServiceClient.parse_common_billing_account_path(path)
     assert expected == actual

+
 def test_common_folder_path():
     folder = "whelk"

-    expected = "folders/{folder}".format(folder=folder, )
+    expected = "folders/{folder}".format(folder=folder,)
     actual = JobServiceClient.common_folder_path(folder)
     assert expected == actual


 def test_parse_common_folder_path():
     expected = {
-    "folder": "octopus",
-
+        "folder": "octopus",
     }
     path = JobServiceClient.common_folder_path(**expected)

@@ -6325,18 +6012,18 @@ def test_parse_common_folder_path():
     actual = JobServiceClient.parse_common_folder_path(path)
     assert expected == actual

+
 def test_common_organization_path():
     organization = "oyster"

-    expected = "organizations/{organization}".format(organization=organization, )
+    expected = "organizations/{organization}".format(organization=organization,)
     actual = JobServiceClient.common_organization_path(organization)
     assert expected == actual


 def test_parse_common_organization_path():
     expected = {
-    "organization": "nudibranch",
-
+        "organization": "nudibranch",
     }
     path = JobServiceClient.common_organization_path(**expected)

@@ -6344,18 +6031,18 @@ def test_parse_common_organization_path():
     actual = JobServiceClient.parse_common_organization_path(path)
     assert expected == actual

+
 def test_common_project_path():
     project = "cuttlefish"

-    expected = "projects/{project}".format(project=project, )
+    expected = "projects/{project}".format(project=project,)
     actual = JobServiceClient.common_project_path(project)
     assert expected == actual


 def test_parse_common_project_path():
     expected = {
-    "project": "mussel",
-
+        "project": "mussel",
     }
     path = JobServiceClient.common_project_path(**expected)

@@ -6363,20 +6050,22 @@ def test_parse_common_project_path():
     actual = JobServiceClient.parse_common_project_path(path)
     assert expected == actual

+
 def test_common_location_path():
     project = "winkle"
     location = "nautilus"

-    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
     actual = JobServiceClient.common_location_path(project, location)
     assert expected == actual


 def test_parse_common_location_path():
     expected = {
-    "project": "scallop",
-    "location": "abalone",
-
+        "project": "scallop",
+        "location": "abalone",
     }
     path = JobServiceClient.common_location_path(**expected)

@@ -6388,17 +6077,19 @@ def test_client_withDEFAULT_CLIENT_INFO():
     client_info = gapic_v1.client_info.ClientInfo()

-    with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.JobServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
         client = JobServiceClient(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
         )
         prep.assert_called_once_with(client_info)

-    with mock.patch.object(transports.JobServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.JobServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
         transport_class = JobServiceClient.get_transport_class()
         transport = transport_class(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
         )
         prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
index f2c4c7cda9..85e6a2d362 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py
@@ -35,8 +35,12 @@
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
-from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceAsyncClient
-from google.cloud.aiplatform_v1beta1.services.migration_service import MigrationServiceClient
+from google.cloud.aiplatform_v1beta1.services.migration_service import (
+    MigrationServiceAsyncClient,
+)
+from google.cloud.aiplatform_v1beta1.services.migration_service import (
+    MigrationServiceClient,
+)
 from google.cloud.aiplatform_v1beta1.services.migration_service import pagers
 from google.cloud.aiplatform_v1beta1.services.migration_service import transports
 from google.cloud.aiplatform_v1beta1.types import migratable_resource
@@ -53,7 +57,11 @@ def client_cert_source_callback():

 # This method modifies the default endpoint so the client can produce a different
 # mtls endpoint for endpoint testing purposes.
 def modify_default_endpoint(client):
-    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT
+    return (
+        "foo.googleapis.com"
+        if ("localhost" in client.DEFAULT_ENDPOINT)
+        else client.DEFAULT_ENDPOINT
+    )


 def test__get_default_mtls_endpoint():
@@ -64,17 +72,36 @@ def test__get_default_mtls_endpoint():
     non_googleapi = "api.example.com"

     assert MigrationServiceClient._get_default_mtls_endpoint(None) is None
-    assert MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
-    assert MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
-    assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
-    assert MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
-    assert MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+    assert (
+        MigrationServiceClient._get_default_mtls_endpoint(api_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        MigrationServiceClient._get_default_mtls_endpoint(non_googleapi)
+        == non_googleapi
+    )


-@pytest.mark.parametrize("client_class", [MigrationServiceClient, MigrationServiceAsyncClient])
+@pytest.mark.parametrize(
+    "client_class", [MigrationServiceClient, MigrationServiceAsyncClient]
+)
 def test_migration_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
-    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_file"
+    ) as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
         assert client.transport._credentials == creds
@@ -82,7 +109,7 @@ def test_migration_service_client_from_service_account_file(client_class):
         client = client_class.from_service_account_json("dummy/file/path.json")
         assert client.transport._credentials == creds

-    assert client.transport._host == 'aiplatform.googleapis.com:443'
+    assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_migration_service_client_get_transport_class():
@@ -93,29 +120,44 @@
     assert transport == transports.MigrationServiceGrpcTransport


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"),
-    (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient))
-@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient))
-def test_migration_service_client_client_options(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"),
+        (
+            MigrationServiceAsyncClient,
+            transports.MigrationServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+@mock.patch.object(
+    MigrationServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(MigrationServiceClient),
+)
+@mock.patch.object(
+    MigrationServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(MigrationServiceAsyncClient),
+)
+def test_migration_service_client_client_options(
+    client_class, transport_class, transport_name
+):
     # Check that if channel is provided we won't create a new one.
-    with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc:
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials()
-        )
+    with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc:
+        transport = transport_class(credentials=credentials.AnonymousCredentials())
         client = client_class(transport=transport)
         gtc.assert_not_called()

     # Check that if channel is provided via str we will create a new one.
-    with mock.patch.object(MigrationServiceClient, 'get_transport_class') as gtc:
+    with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc:
         client = client_class(transport=transport_name)
         gtc.assert_called()

     # Check the case api_endpoint is provided.
     options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -131,7 +173,7 @@
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "never".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -147,7 +189,7 @@
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "always".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -167,13 +209,15 @@
             client = client_class()

     # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
         with pytest.raises(ValueError):
             client = client_class()

     # Check the case quota_project_id is provided
     options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -186,26 +230,66 @@
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
-    (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "true"),
-    (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
-    (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc", "false"),
-    (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio", "false")
-])
-@mock.patch.object(MigrationServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceClient))
-@mock.patch.object(MigrationServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(MigrationServiceAsyncClient))
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name,use_client_cert_env",
+    [
+        (
+            MigrationServiceClient,
+            transports.MigrationServiceGrpcTransport,
+            "grpc",
+            "true",
+        ),
+        (
+            MigrationServiceAsyncClient,
+            transports.MigrationServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "true",
+        ),
+        (
+            MigrationServiceClient,
+            transports.MigrationServiceGrpcTransport,
+            "grpc",
+            "false",
+        ),
+        (
+            MigrationServiceAsyncClient,
+            transports.MigrationServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "false",
+        ),
+    ],
+)
+@mock.patch.object(
+    MigrationServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(MigrationServiceClient),
+)
+@mock.patch.object(
+    MigrationServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(MigrationServiceAsyncClient),
+)
 @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
-def test_migration_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+def test_migration_service_client_mtls_env_auto(
+    client_class, transport_class, transport_name, use_client_cert_env
+):
     # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
     # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.

     # Check the case client_cert_source is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
-        with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        options = client_options.ClientOptions(
+            client_cert_source=client_cert_source_callback
+        )
+        with mock.patch.object(transport_class, "__init__") as patched:
             ssl_channel_creds = mock.Mock()
-            with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds):
+            with mock.patch(
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
+            ):
                 patched.return_value = None
                 client = client_class(client_options=options)

@@ -228,11 +312,21 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t
     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
-                    with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
                         if use_client_cert_env == "false":
                             is_mtls_mock.return_value = False
                             ssl_credentials_mock.return_value = None
@@ -242,7 +336,9 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t
                             is_mtls_mock.return_value = True
                             ssl_credentials_mock.return_value = mock.Mock()
                             expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = ssl_credentials_mock.return_value
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )

                         patched.return_value = None
                         client = client_class()
@@ -257,10 +353,17 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t
                         )

     # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
                     is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
@@ -275,16 +378,23 @@ def test_migration_service_client_mtls_env_auto(client_class, transport_class, t
                     )


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"),
-    (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_migration_service_client_client_options_scopes(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"),
+        (
+            MigrationServiceAsyncClient,
+            transports.MigrationServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_migration_service_client_client_options_scopes(
+    client_class, transport_class, transport_name
+):
     # Check the case scopes are provided.
-    options = client_options.ClientOptions(
-        scopes=["1", "2"],
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(scopes=["1", "2"],)
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -297,16 +407,24 @@
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"),
-    (MigrationServiceAsyncClient, transports.MigrationServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_migration_service_client_client_options_credentials_file(client_class, transport_class, transport_name):
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"),
+        (
+            MigrationServiceAsyncClient,
+            transports.MigrationServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_migration_service_client_client_options_credentials_file(
+    client_class, transport_class, transport_name
+):
     # Check the case credentials file is provided.
-    options = client_options.ClientOptions(
-        credentials_file="credentials.json"
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(credentials_file="credentials.json")
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -321,10 +439,12 @@


 def test_migration_service_client_client_options_from_dict():
-    with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__') as grpc_transport:
+    with mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__"
+    ) as grpc_transport:
         grpc_transport.return_value = None
         client = MigrationServiceClient(
-            client_options={'api_endpoint': 'squid.clam.whelk'}
+            client_options={"api_endpoint": "squid.clam.whelk"}
         )
         grpc_transport.assert_called_once_with(
             credentials=None,
@@ -337,10 +457,12 @@ def test_migration_service_client_client_options_from_dict():
         )


-def test_search_migratable_resources(transport: str = 'grpc', request_type=migration_service.SearchMigratableResourcesRequest):
+def test_search_migratable_resources(
+    transport: str = "grpc",
+    request_type=migration_service.SearchMigratableResourcesRequest,
+):
     client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -349,12 +471,11 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = migration_service.SearchMigratableResourcesResponse(
-            next_page_token='next_page_token_value',
-
+            next_page_token="next_page_token_value",
         )

         response = client.search_migratable_resources(request)

@@ -369,7 +490,7 @@ def test_search_migratable_resources(transport: str = 'grpc', request_type=migra

     assert isinstance(response, pagers.SearchMigratableResourcesPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_search_migratable_resources_from_dict():
@@ -377,10 +498,12 @@

 @pytest.mark.asyncio
-async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.SearchMigratableResourcesRequest):
+async def test_search_migratable_resources_async(
+    transport: str = "grpc_asyncio",
+    request_type=migration_service.SearchMigratableResourcesRequest,
+):
     client = MigrationServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -389,12 +512,14 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio'

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse(
-            next_page_token='next_page_token_value',
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            migration_service.SearchMigratableResourcesResponse(
+                next_page_token="next_page_token_value",
+            )
+        )

         response = await client.search_migratable_resources(request)

@@ -407,7 +532,7 @@ async def test_search_migratable_resources_async(transport: str = 'grpc_asyncio'
     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 @pytest.mark.asyncio
@@ -416,19 +541,17 @@ async def test_search_migratable_resources_async_from_dict():


 def test_search_migratable_resources_field_headers():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = migration_service.SearchMigratableResourcesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         call.return_value = migration_service.SearchMigratableResourcesResponse()

         client.search_migratable_resources(request)

@@ -440,10 +563,7 @@ def test_search_migratable_resources_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
@@ -455,13 +575,15 @@ async def test_search_migratable_resources_field_headers_async():
     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = migration_service.SearchMigratableResourcesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse())
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            migration_service.SearchMigratableResourcesResponse()
+        )

         await client.search_migratable_resources(request)

@@ -472,49 +594,39 @@ async def test_search_migratable_resources_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_search_migratable_resources_flattened():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = migration_service.SearchMigratableResourcesResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.search_migratable_resources(
-            parent='parent_value',
-        )
+        client.search_migratable_resources(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 def test_search_migratable_resources_flattened_error():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.search_migratable_resources(
-            migration_service.SearchMigratableResourcesRequest(),
-            parent='parent_value',
+            migration_service.SearchMigratableResourcesRequest(), parent="parent_value",
         )

@@ -526,24 +638,24 @@ async def test_search_migratable_resources_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = migration_service.SearchMigratableResourcesResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(migration_service.SearchMigratableResourcesResponse())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            migration_service.SearchMigratableResourcesResponse()
+        )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.search_migratable_resources(
-            parent='parent_value',
-        )
+        response = await client.search_migratable_resources(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 @pytest.mark.asyncio
@@ -556,20 +668,17 @@ async def test_search_migratable_resources_flattened_error_async():
     # fields is an error.
     with pytest.raises(ValueError):
         await client.search_migratable_resources(
-            migration_service.SearchMigratableResourcesRequest(),
-            parent='parent_value',
+            migration_service.SearchMigratableResourcesRequest(), parent="parent_value",
         )


 def test_search_migratable_resources_pager():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             migration_service.SearchMigratableResourcesResponse(
@@ -578,17 +687,14 @@ def test_search_migratable_resources_pager():
                     migratable_resource.MigratableResource(),
                     migratable_resource.MigratableResource(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[],
-                next_page_token='def',
+                migratable_resources=[], next_page_token="def",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[
-                    migratable_resource.MigratableResource(),
-                ],
-                next_page_token='ghi',
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
             ),
             migration_service.SearchMigratableResourcesResponse(
                 migratable_resources=[
@@ -601,9 +707,7 @@ def test_search_migratable_resources_pager():
         metadata = ()

         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', ''),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
         )

         pager = client.search_migratable_resources(request={})

@@ -611,18 +715,18 @@ def test_search_migratable_resources_pager():
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(isinstance(i, migratable_resource.MigratableResource)
-                   for i in results)
+        assert all(
+            isinstance(i, migratable_resource.MigratableResource) for i in results
+        )
+

 def test_search_migratable_resources_pages():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__') as call:
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             migration_service.SearchMigratableResourcesResponse(
@@ -631,17 +735,14 @@ def test_search_migratable_resources_pages():
                     migratable_resource.MigratableResource(),
                     migratable_resource.MigratableResource(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[],
-                next_page_token='def',
+                migratable_resources=[], next_page_token="def",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[
-                    migratable_resource.MigratableResource(),
-                ],
-                next_page_token='ghi',
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
             ),
             migration_service.SearchMigratableResourcesResponse(
                 migratable_resources=[
@@ -652,19 +753,20 @@ def test_search_migratable_resources_pages():
             RuntimeError,
         )
         pages = list(client.search_migratable_resources(request={}).pages)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
             assert page_.raw_page.next_page_token == token

+
 @pytest.mark.asyncio
 async def test_search_migratable_resources_async_pager():
-    client = MigrationServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.search_migratable_resources),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             migration_service.SearchMigratableResourcesResponse(
@@ -673,17 +775,14 @@ async def test_search_migratable_resources_async_pager():
                     migratable_resource.MigratableResource(),
                     migratable_resource.MigratableResource(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[],
-                next_page_token='def',
+                migratable_resources=[], next_page_token="def",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[
-                    migratable_resource.MigratableResource(),
-                ],
-                next_page_token='ghi',
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
             ),
             migration_service.SearchMigratableResourcesResponse(
                 migratable_resources=[
@@ -694,25 +793,27 @@ async def test_search_migratable_resources_async_pager():
             RuntimeError,
         )
         async_pager = await client.search_migratable_resources(request={},)
-        assert async_pager.next_page_token == 'abc'
+        assert async_pager.next_page_token == "abc"
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(isinstance(i, migratable_resource.MigratableResource)
-                   for i in responses)
+        assert all(
+            isinstance(i, migratable_resource.MigratableResource) for i in responses
+        )
+

 @pytest.mark.asyncio
 async def test_search_migratable_resources_async_pages():
-    client = MigrationServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.search_migratable_resources),
-        '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.search_migratable_resources),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             migration_service.SearchMigratableResourcesResponse(
@@ -721,17 +822,14 @@ async def test_search_migratable_resources_async_pages():
                     migratable_resource.MigratableResource(),
                     migratable_resource.MigratableResource(),
                 ],
-                next_page_token='abc',
+                next_page_token="abc",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[],
-                next_page_token='def',
+                migratable_resources=[], next_page_token="def",
             ),
             migration_service.SearchMigratableResourcesResponse(
-                migratable_resources=[
-                    migratable_resource.MigratableResource(),
-                ],
-                next_page_token='ghi',
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
             ),
             migration_service.SearchMigratableResourcesResponse(
                 migratable_resources=[
@@ -744,14 +842,15 @@ async def test_search_migratable_resources_async_pages():
         pages = []
         async for page_ in (await client.search_migratable_resources(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
            assert page_.raw_page.next_page_token == token


-def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration_service.BatchMigrateResourcesRequest):
+def test_batch_migrate_resources(
+    transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest
+):
     client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -760,10 +859,10 @@ def test_batch_migrate_resources(transport: str = 'grpc', request_type=migration

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.batch_migrate_resources),
-        '__call__') as call:
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/spam')
+        call.return_value = operations_pb2.Operation(name="operations/spam")

         response = client.batch_migrate_resources(request)

@@ -782,10 +881,12 @@ def test_batch_migrate_resources_from_dict():


 @pytest.mark.asyncio
-async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', request_type=migration_service.BatchMigrateResourcesRequest):
+async def test_batch_migrate_resources_async(
+    transport: str = "grpc_asyncio",
+    request_type=migration_service.BatchMigrateResourcesRequest,
+):
     client = MigrationServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -794,11 +895,11 @@ async def test_batch_migrate_resources_async(transport: str = 'grpc_asyncio', re

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.batch_migrate_resources),
-        '__call__') as call:
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )

         response = await client.batch_migrate_resources(request)

@@ -819,20 +920,18 @@ async def test_batch_migrate_resources_async_from_dict():

 def test_batch_migrate_resources_field_headers():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = migration_service.BatchMigrateResourcesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.batch_migrate_resources),
-        '__call__') as call:
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")

         client.batch_migrate_resources(request)

@@ -843,10 +942,7 @@ def test_batch_migrate_resources_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
@@ -858,13 +954,15 @@ async def test_batch_migrate_resources_field_headers_async():
     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = migration_service.BatchMigrateResourcesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.batch_migrate_resources),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )

         await client.batch_migrate_resources(request)

@@ -875,29 +973,30 @@ async def test_batch_migrate_resources_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_batch_migrate_resources_flattened():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.batch_migrate_resources),
-        '__call__') as call:
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.batch_migrate_resources(
-            parent='parent_value',
-            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
         )

         # Establish that the underlying call was made with the expected
@@ -905,23 +1004,33 @@ def test_batch_migrate_resources_flattened():
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))]
+        assert args[0].migrate_resource_requests == [
+            migration_service.MigrateResourceRequest(
+                migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                    endpoint="endpoint_value"
+                )
+            )
+        ]


 def test_batch_migrate_resources_flattened_error():
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.batch_migrate_resources(
             migration_service.BatchMigrateResourcesRequest(),
-            parent='parent_value',
-            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
         )


@@ -933,19 +1042,25 @@ async def test_batch_migrate_resources_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.batch_migrate_resources),
-        '__call__') as call:
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.batch_migrate_resources(
-            parent='parent_value',
-            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
         )

         # Establish that the underlying call was made with the expected
@@ -953,9 +1068,15 @@ async def test_batch_migrate_resources_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].migrate_resource_requests == [migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))]
+        assert args[0].migrate_resource_requests == [
+            migration_service.MigrateResourceRequest(
+                migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                    endpoint="endpoint_value"
+                )
+            )
+        ]


 @pytest.mark.asyncio
@@ -969,8 +1090,14 @@ async def test_batch_migrate_resources_flattened_error_async():
     with pytest.raises(ValueError):
         await client.batch_migrate_resources(
             migration_service.BatchMigrateResourcesRequest(),
-            parent='parent_value',
-            migrate_resource_requests=[migration_service.MigrateResourceRequest(migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(endpoint='endpoint_value'))],
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
         )


@@ -981,8 +1108,7 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = MigrationServiceClient(
-            credentials=credentials.AnonymousCredentials(),
-            transport=transport,
+            credentials=credentials.AnonymousCredentials(), transport=transport,
         )

     # It is an error to provide a credentials file and a transport instance.
@@ -1001,8 +1127,7 @@ def test_credentials_transport_error():
     )
     with pytest.raises(ValueError):
         client = MigrationServiceClient(
-            client_options={"scopes": ["1", "2"]},
-            transport=transport,
+            client_options={"scopes": ["1", "2"]}, transport=transport,
         )


@@ -1030,13 +1155,16 @@ def test_transport_get_channel():
     assert channel


-@pytest.mark.parametrize("transport_class", [
-    transports.MigrationServiceGrpcTransport,
-    transports.MigrationServiceGrpcAsyncIOTransport
-])
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
 def test_transport_adc(transport_class):
     # Test default credentials are used if not provided.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport_class()
         adc.assert_called_once()


 def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
-    client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    assert isinstance(
-        client.transport,
-        transports.MigrationServiceGrpcTransport,
-    )
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+    assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,)


 def test_migration_service_base_transport_error():
@@ -1058,13 +1181,15 @@
     with pytest.raises(exceptions.DuplicateCredentialArgs):
         transport = transports.MigrationServiceTransport(
             credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json"
+            credentials_file="credentials.json",
         )


 def test_migration_service_base_transport():
     # Instantiate the base transport.
-    with mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__') as Transport:
+    with mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__"
+    ) as Transport:
         Transport.return_value = None
         transport = transports.MigrationServiceTransport(
             credentials=credentials.AnonymousCredentials(),
         )

     # Every method on the transport should just blindly
     # raise NotImplementedError.
     methods = (
-        'search_migratable_resources',
-        'batch_migrate_resources',
-        )
+        "search_migratable_resources",
+        "batch_migrate_resources",
+    )
     for method in methods:
         with pytest.raises(NotImplementedError):
             getattr(transport, method)(request=object())

@@ -1088,23 +1213,28 @@ def test_migration_service_base_transport_with_credentials_file():
     # Instantiate the base transport with a credentials file
-    with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport:
+    with mock.patch.object(
+        auth, "load_credentials_from_file"
+    ) as load_creds, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages"
+    ) as Transport:
         Transport.return_value = None
         load_creds.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.MigrationServiceTransport(
-            credentials_file="credentials.json",
-            quota_project_id="octopus",
+            credentials_file="credentials.json", quota_project_id="octopus",
         )
-        load_creds.assert_called_once_with("credentials.json", scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',
-            ),
+        load_creds.assert_called_once_with(
+            "credentials.json",
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
             quota_project_id="octopus",
         )


 def test_migration_service_base_transport_with_adc():
     # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages') as Transport:
+    with mock.patch.object(auth, "default") as adc, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages"
+    ) as Transport:
         Transport.return_value = None
         adc.return_value = (credentials.AnonymousCredentials(), None)
         transport = transports.MigrationServiceTransport()
@@ -1113,11 +1243,11 @@ def test_migration_service_auth_adc():
     # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
         MigrationServiceClient()
-        adc.assert_called_once_with(scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',),
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
             quota_project_id=None,
         )

@@ -1125,37 +1255,43 @@ def test_migration_service_transport_auth_adc():
     # If credentials and host are not provided, the transport class should use
     # ADC credentials.
-    with mock.patch.object(auth, 'default') as adc:
+    with mock.patch.object(auth, "default") as adc:
         adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.MigrationServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus")
-        adc.assert_called_once_with(scopes=(
-            'https://www.googleapis.com/auth/cloud-platform',),
+        transports.MigrationServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
             quota_project_id="octopus",
         )

+
 def test_migration_service_host_no_port():
     client = MigrationServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
     )
-    assert client.transport._host == 'aiplatform.googleapis.com:443'
+    assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_migration_service_host_with_port():
     client = MigrationServiceClient(
         credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
     )
-    assert client.transport._host == 'aiplatform.googleapis.com:8000'
+    assert client.transport._host == "aiplatform.googleapis.com:8000"


 def test_migration_service_grpc_transport_channel():
-    channel = grpc.insecure_channel('http://localhost/')
+    channel = grpc.insecure_channel("http://localhost/")

     # Check that channel is used if provided.
     transport = transports.MigrationServiceGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
@@ -1163,24 +1299,33 @@ def test_migration_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel('http://localhost/')
+    channel = aio.insecure_channel("http://localhost/")

     # Check that channel is used if provided.
     transport = transports.MigrationServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
     assert transport._ssl_channel_credentials == None


-@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport])
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
 def test_migration_service_transport_channel_mtls_with_client_cert_source(
-    transport_class
+    transport_class,
 ):
-    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
-        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred
@@ -1189,7 +1334,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source(
             cred = credentials.AnonymousCredentials()
             with pytest.warns(DeprecationWarning):
-                with mock.patch.object(auth, 'default') as adc:
+                with mock.patch.object(auth, "default") as adc:
                     adc.return_value = (cred, None)
                     transport = transport_class(
                         host="squid.clam.whelk",
@@ -1205,9 +1350,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source(
                 "mtls.squid.clam.whelk:443",
                 credentials=cred,
                 credentials_file=None,
-                scopes=(
-                    'https://www.googleapis.com/auth/cloud-platform',
-                ),
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
             )
@@ -1215,17 +1358,23 @@
             assert transport._ssl_channel_credentials == mock_ssl_cred


-@pytest.mark.parametrize("transport_class", [transports.MigrationServiceGrpcTransport, transports.MigrationServiceGrpcAsyncIOTransport])
-def test_migration_service_transport_channel_mtls_with_adc(
-    transport_class
-):
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_migration_service_transport_channel_mtls_with_adc(transport_class):
     mock_ssl_cred = mock.Mock()
     with mock.patch.multiple(
         "google.auth.transport.grpc.SslCredentials",
         __init__=mock.Mock(return_value=None),
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
-        with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
             mock_cred = mock.Mock()
@@ -1242,9 +1391,7 @@ def test_migration_service_transport_channel_mtls_with_adc(
                 "mtls.squid.clam.whelk:443",
                 credentials=mock_cred,
                 credentials_file=None,
-                scopes=(
-                    'https://www.googleapis.com/auth/cloud-platform',
-                ),
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
             )


@@ -1253,16 +1400,12 @@ def test_migration_service_grpc_lro_client():
     client = MigrationServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc',
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
     )
     transport = client.transport

     # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client

@@ -1270,36 +1413,36 @@ def test_migration_service_grpc_lro_async_client():
     client = MigrationServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc_asyncio',
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
     )
     transport = client.transport

     # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsAsyncClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client

+
 def test_annotated_dataset_path():
     project = "squid"
     dataset = "clam"
     annotated_dataset = "whelk"

-    expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(project=project, dataset=dataset, annotated_dataset=annotated_dataset, )
-    actual = MigrationServiceClient.annotated_dataset_path(project, dataset, annotated_dataset)
+    expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(
+        project=project, dataset=dataset, annotated_dataset=annotated_dataset,
+    )
+    actual = MigrationServiceClient.annotated_dataset_path(
+        project, dataset, annotated_dataset
+    )
     assert expected == actual


 def test_parse_annotated_dataset_path():
     expected = {
-    "project": "octopus",
-    "dataset": "oyster",
-    "annotated_dataset": "nudibranch",
-
+        "project": "octopus",
+        "dataset": "oyster",
+        "annotated_dataset": "nudibranch",
     }
     path = MigrationServiceClient.annotated_dataset_path(**expected)

@@ -1307,22 +1450,24 @@ def test_parse_annotated_dataset_path():
     actual = MigrationServiceClient.parse_annotated_dataset_path(path)
     assert expected == actual

+
 def test_dataset_path():
     project = "cuttlefish"
     location = "mussel"
     dataset = "winkle"

-    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, )
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
     actual = MigrationServiceClient.dataset_path(project, location, dataset)
     assert expected == actual


 def test_parse_dataset_path():
     expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "dataset": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "dataset": "abalone",
     }
     path = MigrationServiceClient.dataset_path(**expected)

@@ -1330,20 +1475,22 @@ def test_parse_dataset_path():
     actual = MigrationServiceClient.parse_dataset_path(path)
     assert expected == actual

+
 def test_dataset_path():
     project = "squid"
     dataset = "clam"

-    expected = "projects/{project}/datasets/{dataset}".format(project=project, dataset=dataset, )
+    expected = "projects/{project}/datasets/{dataset}".format(
+        project=project, dataset=dataset,
+    )
     actual = MigrationServiceClient.dataset_path(project, dataset)
     assert expected == actual


 def test_parse_dataset_path():
     expected = {
-    "project": "whelk",
-    "dataset": "octopus",
-
+        "project": "whelk",
+        "dataset": "octopus",
     }
     path = MigrationServiceClient.dataset_path(**expected)

@@ -1351,22 +1498,24 @@ def test_parse_dataset_path():
     actual = MigrationServiceClient.parse_dataset_path(path)
     assert expected == actual

+
 def test_dataset_path():
     project = "oyster"
     location = "nudibranch"
     dataset = "cuttlefish"

-    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(project=project, location=location, dataset=dataset, )
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
     actual = MigrationServiceClient.dataset_path(project, location, dataset)
     assert expected == actual


 def test_parse_dataset_path():
     expected = {
-    "project": "mussel",
-    "location": "winkle",
-    "dataset": "nautilus",
-
+        "project": "mussel",
+        "location": "winkle",
+        "dataset": "nautilus",
     }
     path = MigrationServiceClient.dataset_path(**expected)

@@ -1374,22 +1523,24 @@ def test_parse_dataset_path():
     actual = MigrationServiceClient.parse_dataset_path(path)
     assert expected == actual

+
 def test_model_path():
     project = "scallop"
     location = "abalone"
     model = "squid"

-    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
     actual = MigrationServiceClient.model_path(project, location, model)
     assert expected == actual


 def test_parse_model_path():
     expected = {
-    "project": "clam",
-    "location": "whelk",
-    "model": "octopus",
-
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
     }
     path = MigrationServiceClient.model_path(**expected)

@@ -1397,22 +1548,24 @@ def test_parse_model_path():
     actual = MigrationServiceClient.parse_model_path(path)
     assert expected == actual

+
 def test_model_path():
     project = "oyster"
     location = "nudibranch"
     model = "cuttlefish"

-    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
     actual = MigrationServiceClient.model_path(project, location, model)
     assert expected == actual


 def test_parse_model_path():
     expected = {
-    "project": "mussel",
-    "location": "winkle",
-    "model": "nautilus",
-
+        "project": "mussel",
+        "location": "winkle",
+        "model": "nautilus",
     }
     path = MigrationServiceClient.model_path(**expected)

@@ -1420,22 +1573,24 @@ def test_parse_model_path():
     actual = MigrationServiceClient.parse_model_path(path)
     assert expected == actual

+
 def test_version_path():
     project = "scallop"
     model = "abalone"
     version = "squid"

-    expected = "projects/{project}/models/{model}/versions/{version}".format(project=project, model=model, version=version, )
+    expected = "projects/{project}/models/{model}/versions/{version}".format(
+        project=project, model=model, version=version,
+    )
     actual = MigrationServiceClient.version_path(project, model, version)
     assert expected == actual


 def test_parse_version_path():
     expected = {
-    "project": "clam",
-    "model": "whelk",
-    "version": "octopus",
-
+        "project": "clam",
+        "model": "whelk",
+        "version": "octopus",
     }
     path = MigrationServiceClient.version_path(**expected)

@@ -1443,18 +1598,20 @@ def test_parse_version_path():
     actual = MigrationServiceClient.parse_version_path(path)
     assert expected == actual

+
 def test_common_billing_account_path():
     billing_account = "oyster"

-    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
     actual = MigrationServiceClient.common_billing_account_path(billing_account)
     assert expected == actual


 def test_parse_common_billing_account_path():
     expected = {
-    "billing_account": "nudibranch",
-
+        "billing_account": "nudibranch",
     }
     path = MigrationServiceClient.common_billing_account_path(**expected)

@@ -1462,18 +1619,18 @@ def test_parse_common_billing_account_path():
     actual = MigrationServiceClient.parse_common_billing_account_path(path)
     assert expected == actual

+
 def test_common_folder_path():
     folder = "cuttlefish"

-    expected = "folders/{folder}".format(folder=folder, )
+    expected = "folders/{folder}".format(folder=folder,)
     actual = MigrationServiceClient.common_folder_path(folder)
     assert expected == actual


 def test_parse_common_folder_path():
     expected = {
-    "folder": "mussel",
-
+        "folder": "mussel",
     }
     path = MigrationServiceClient.common_folder_path(**expected)

@@ -1481,18 +1638,18 @@ def test_parse_common_folder_path():
     actual = MigrationServiceClient.parse_common_folder_path(path)
     assert expected == actual

+
 def test_common_organization_path():
     organization = "winkle"

-    expected = "organizations/{organization}".format(organization=organization, )
+    expected = "organizations/{organization}".format(organization=organization,)
     actual = MigrationServiceClient.common_organization_path(organization)
     assert expected == actual


 def test_parse_common_organization_path():
     expected = {
-    "organization": "nautilus",
-
+        "organization": "nautilus",
     }
     path = MigrationServiceClient.common_organization_path(**expected)

@@ -1500,18 +1657,18 @@ def test_parse_common_organization_path():
     actual = MigrationServiceClient.parse_common_organization_path(path)
     assert expected == actual

+
 def test_common_project_path():
     project = "scallop"

-    expected = "projects/{project}".format(project=project, )
+    expected = "projects/{project}".format(project=project,)
     actual = MigrationServiceClient.common_project_path(project)
     assert expected == actual


 def test_parse_common_project_path():
     expected = {
-    "project": "abalone",
-
+        "project": "abalone",
     }
     path = MigrationServiceClient.common_project_path(**expected)

@@ -1519,20 +1676,22 @@ def test_parse_common_project_path():
     actual = MigrationServiceClient.parse_common_project_path(path)
     assert expected == actual

+
 def test_common_location_path():
     project = "squid"
     location = "clam"

-    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
     actual = MigrationServiceClient.common_location_path(project, location)
     assert expected == actual


 def test_parse_common_location_path():
     expected = {
-    "project": "whelk",
-    "location": "octopus",
-
+        "project": "whelk",
+        "location": "octopus",
     }
     path = MigrationServiceClient.common_location_path(**expected)

@@ -1544,17 +1703,19 @@ def test_parse_common_location_path():

 def test_client_withDEFAULT_CLIENT_INFO():
     client_info = gapic_v1.client_info.ClientInfo()

-    with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.MigrationServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
         client = MigrationServiceClient(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
         )
         prep.assert_called_once_with(client_info)

-    with mock.patch.object(transports.MigrationServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.MigrationServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
         transport_class = MigrationServiceClient.get_transport_class()
         transport = transport_class(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
         )
         prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
index af1e117cc3..d05698a46a 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
@@ -35,7 +35,9 @@
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
-from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceAsyncClient
+from google.cloud.aiplatform_v1beta1.services.model_service import (
+    ModelServiceAsyncClient,
+)
 from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient
 from google.cloud.aiplatform_v1beta1.services.model_service import pagers
 from google.cloud.aiplatform_v1beta1.services.model_service import transports
@@ -65,7 +67,11 @@ def client_cert_source_callback():

 # This method modifies the default endpoint so the client can produce a different
 # mtls endpoint for endpoint testing purposes.
 def modify_default_endpoint(client):
-    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT
+    return (
+        "foo.googleapis.com"
+        if ("localhost" in client.DEFAULT_ENDPOINT)
+        else client.DEFAULT_ENDPOINT
+    )


 def test__get_default_mtls_endpoint():
@@ -76,17 +82,30 @@ def test__get_default_mtls_endpoint():
     non_googleapi = "api.example.com"

     assert ModelServiceClient._get_default_mtls_endpoint(None) is None
-    assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
-    assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
-    assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
-    assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
+    assert (
+        ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
+    )
+    assert (
+        ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
+        == sandbox_mtls_endpoint
+    )
     assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi


 @pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient])
 def test_model_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
-    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_file"
+    ) as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
         assert client.transport._credentials == creds
@@ -94,7 +113,7 @@ def test_model_service_client_from_service_account_file(client_class):
         client = client_class.from_service_account_json("dummy/file/path.json")
         assert client.transport._credentials == creds

-    assert client.transport._host == 'aiplatform.googleapis.com:443'
+    assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_model_service_client_get_transport_class():
@@ -105,29 +124,42 @@ def test_model_service_client_get_transport_class():
     assert transport == transports.ModelServiceGrpcTransport


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"),
-    (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient))
-@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient))
-def test_model_service_client_client_options(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"),
+        (
+            ModelServiceAsyncClient,
+            transports.ModelServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+@mock.patch.object(
+    ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)
+)
+@mock.patch.object(
+    ModelServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(ModelServiceAsyncClient),
+)
+def test_model_service_client_client_options(
+    client_class, transport_class, transport_name
+):
     # Check that if channel is provided we won't create a new one.
-    with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc:
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials()
-        )
+    with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc:
+        transport = transport_class(credentials=credentials.AnonymousCredentials())
         client = client_class(transport=transport)
         gtc.assert_not_called()

     # Check that if channel is provided via str we will create a new one.
-    with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc:
+    with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc:
         client = client_class(transport=transport_name)
         gtc.assert_called()

     # Check the case api_endpoint is provided.
     options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -143,7 +175,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "never".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -159,7 +191,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "always".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -179,13 +211,15 @@ def test_model_service_client_client_options(client_class, transport_class, tran
             client = client_class()

     # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
         with pytest.raises(ValueError):
             client = client_class()

     # Check the case quota_project_id is provided
     options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -198,26 +232,54 @@ def test_model_service_client_client_options(client_class, transport_class, tran
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
-    (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"),
-    (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
-    (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"),
-    (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false")
-])
-@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient))
-@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient))
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name,use_client_cert_env",
+    [
+        (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"),
+        (
+            ModelServiceAsyncClient,
+            transports.ModelServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "true",
+        ),
+        (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"),
+        (
+            ModelServiceAsyncClient,
+            transports.ModelServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "false",
+        ),
+    ],
+)
+@mock.patch.object(
+    ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)
+)
+@mock.patch.object(
+    ModelServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(ModelServiceAsyncClient),
+)
 @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
-def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+def test_model_service_client_mtls_env_auto(
+    client_class, transport_class, transport_name, use_client_cert_env
+):
     # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
     # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.

     # Check the case client_cert_source is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
-        with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        options = client_options.ClientOptions(
+            client_cert_source=client_cert_source_callback
+        )
+        with mock.patch.object(transport_class, "__init__") as patched:
             ssl_channel_creds = mock.Mock()
-            with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds):
+            with mock.patch(
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
+            ):
                 patched.return_value = None
                 client = client_class(client_options=options)

@@ -240,11 +302,21 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans
     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
-                    with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
                         if use_client_cert_env == "false":
                             is_mtls_mock.return_value = False
                             ssl_credentials_mock.return_value = None
@@ -254,7 +326,9 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans
                             is_mtls_mock.return_value = True
                             ssl_credentials_mock.return_value = mock.Mock()
                             expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = ssl_credentials_mock.return_value
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )

                         patched.return_value = None
                         client = client_class()
@@ -269,10 +343,17 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans
                         )

     # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
                     is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
@@ -287,16 +368,23 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans
                     )


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"),
-    (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"),
+        (
+            ModelServiceAsyncClient,
+            transports.ModelServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_model_service_client_client_options_scopes(
+    client_class, transport_class, transport_name
+):
     # Check the case scopes are provided.
-    options = client_options.ClientOptions(
-        scopes=["1", "2"],
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(scopes=["1", "2"],)
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -309,16 +397,24 @@ def test_model_service_client_client_options_scopes(client_class, transport_clas
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"),
-    (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name):
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"),
+        (
+            ModelServiceAsyncClient,
+            transports.ModelServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_model_service_client_client_options_credentials_file(
+    client_class, transport_class, transport_name
+):
     # Check the case credentials file is provided.
-    options = client_options.ClientOptions(
-        credentials_file="credentials.json"
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(credentials_file="credentials.json")
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -333,11 +429,11 @@ def test_model_service_client_client_options_credentials_file(client_class, tran


 def test_model_service_client_client_options_from_dict():
-    with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport:
+    with mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__"
+    ) as grpc_transport:
         grpc_transport.return_value = None
-        client = ModelServiceClient(
-            client_options={'api_endpoint': 'squid.clam.whelk'}
-        )
+        client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"})
         grpc_transport.assert_called_once_with(
             credentials=None,
             credentials_file=None,
@@ -349,10 +445,11 @@ def test_model_service_client_client_options_from_dict():
         )


-def test_upload_model(transport: str = 'grpc', request_type=model_service.UploadModelRequest):
+def test_upload_model(
+    transport: str = "grpc", request_type=model_service.UploadModelRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -360,11 +457,9 @@ def test_upload_model(transport: str = 'grpc', request_type=model_service.Upload
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.upload_model),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/spam')
+        call.return_value = operations_pb2.Operation(name="operations/spam")

         response = client.upload_model(request)

@@ -383,10 +478,11 @@ def test_upload_model_from_dict():


 @pytest.mark.asyncio
-async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UploadModelRequest):
+async def test_upload_model_async(
+    transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest
+):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -394,12 +490,10 @@ async def test_upload_model_async(transport: str = 'grpc_asyncio', request_type=
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.upload_model),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )

         response = await client.upload_model(request)

@@ -420,20 +514,16 @@ async def test_upload_model_async_from_dict():


 def test_upload_model_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.UploadModelRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.upload_model),
-        '__call__') as call:
-        call.return_value = operations_pb2.Operation(name='operations/op')
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")

         client.upload_model(request)

@@ -444,28 +534,23 @@ def test_upload_model_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_upload_model_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.UploadModelRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.upload_model),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )

         await client.upload_model(request)

@@ -476,29 +561,21 @@ async def test_upload_model_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_upload_model_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.upload_model),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.upload_model(
-            parent='parent_value',
-            model=gca_model.Model(name='name_value'),
+            parent="parent_value", model=gca_model.Model(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -506,47 +583,40 @@ def test_upload_model_flattened():
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].model == gca_model.Model(name='name_value')
+        assert args[0].model == gca_model.Model(name="name_value")


 def test_upload_model_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.upload_model(
             model_service.UploadModelRequest(),
-            parent='parent_value',
-            model=gca_model.Model(name='name_value'),
+            parent="parent_value",
+            model=gca_model.Model(name="name_value"),
         )


 @pytest.mark.asyncio
 async def test_upload_model_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.upload_model),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.upload_model(
-            parent='parent_value',
-            model=gca_model.Model(name='name_value'),
+            parent="parent_value", model=gca_model.Model(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -554,31 +624,28 @@ async def test_upload_model_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].model == gca_model.Model(name='name_value')
+        assert args[0].model == gca_model.Model(name="name_value")


 @pytest.mark.asyncio
 async def test_upload_model_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.upload_model(
             model_service.UploadModelRequest(),
-            parent='parent_value',
-            model=gca_model.Model(name='name_value'),
+            parent="parent_value",
+            model=gca_model.Model(name="name_value"),
         )


-def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelRequest):
+def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -586,31 +653,21 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client.transport.get_model),
-        '__call__') as call:
+    with mock.patch.object(type(client.transport.get_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = model.Model(
-            name='name_value',
-
-            display_name='display_name_value',
-
-            description='description_value',
-
-            metadata_schema_uri='metadata_schema_uri_value',
-
-            training_pipeline='training_pipeline_value',
-
-            artifact_uri='artifact_uri_value',
-
-            supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES],
-
-            supported_input_storage_formats=['supported_input_storage_formats_value'],
-
-            supported_output_storage_formats=['supported_output_storage_formats_value'],
-
-            etag='etag_value',
-
+            name="name_value",
+            display_name="display_name_value",
+            description="description_value",
+            metadata_schema_uri="metadata_schema_uri_value",
+            training_pipeline="training_pipeline_value",
+            artifact_uri="artifact_uri_value",
+            supported_deployment_resources_types=[
+                model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+            ],
+            supported_input_storage_formats=["supported_input_storage_formats_value"],
+            supported_output_storage_formats=["supported_output_storage_formats_value"],
+            etag="etag_value",
         )

         response = client.get_model(request)

@@ -625,25 +682,31 @@ def test_get_model(transport: str = 'grpc', request_type=model_service.GetModelR
     assert isinstance(response, model.Model)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.description == 'description_value'
+    assert response.description == "description_value"

-    assert response.metadata_schema_uri == 'metadata_schema_uri_value'
+    assert response.metadata_schema_uri == "metadata_schema_uri_value"

-    assert response.training_pipeline == 'training_pipeline_value'
+    assert response.training_pipeline == "training_pipeline_value"

-    assert response.artifact_uri == 'artifact_uri_value'
+    assert response.artifact_uri == "artifact_uri_value"

-    assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES]
+    assert response.supported_deployment_resources_types == [
+        model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+    ]

-    assert response.supported_input_storage_formats == ['supported_input_storage_formats_value']
+    assert response.supported_input_storage_formats == [
+        "supported_input_storage_formats_value"
+    ]

-    assert response.supported_output_storage_formats == ['supported_output_storage_formats_value']
response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" def test_get_model_from_dict(): @@ -651,10 +714,11 @@ def test_get_model_from_dict(): @pytest.mark.asyncio -async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -662,22 +726,28 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model( - name='name_value', - display_name='display_name_value', - description='description_value', - metadata_schema_uri='metadata_schema_uri_value', - training_pipeline='training_pipeline_value', - artifact_uri='artifact_uri_value', - supported_deployment_resources_types=[model.Model.DeploymentResourcesType.DEDICATED_RESOURCES], - supported_input_storage_formats=['supported_input_storage_formats_value'], - supported_output_storage_formats=['supported_output_storage_formats_value'], - etag='etag_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) response = await client.get_model(request) @@ -690,25 +760,31 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod # Establish that the response is the type that we expect. 
assert isinstance(response, model.Model) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" - assert response.description == 'description_value' + assert response.description == "description_value" - assert response.metadata_schema_uri == 'metadata_schema_uri_value' + assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == 'training_pipeline_value' + assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == 'artifact_uri_value' + assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [model.Model.DeploymentResourcesType.DEDICATED_RESOURCES] + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] - assert response.supported_input_storage_formats == ['supported_input_storage_formats_value'] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] - assert response.supported_output_storage_formats == ['supported_output_storage_formats_value'] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] - assert response.etag == 'etag_value' + assert response.etag == "etag_value" @pytest.mark.asyncio @@ -717,19 +793,15 @@ async def test_get_model_async_from_dict(): def test_get_model_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = model.Model() client.get_model(request) @@ -741,27 +813,20 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -773,99 +838,79 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model( - name='name_value', - ) + client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model( - name='name_value', - ) + response = await client.get_model(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.get_model( - model_service.GetModelRequest(), - name='name_value', + model_service.GetModelRequest(), name="name_value", ) -def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest): +def test_list_models( + transport: str = "grpc", request_type=model_service.ListModelsRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -873,13 +918,10 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_models(request) @@ -894,7 +936,7 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_models_from_dict(): @@ -902,10 +944,11 @@ def test_list_models_from_dict(): @pytest.mark.asyncio -async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -913,13 +956,11 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse(next_page_token="next_page_token_value",) + ) response = await client.list_models(request) @@ -932,7 +973,7 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -941,19 +982,15 @@ async def test_list_models_async_from_dict(): def test_list_models_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. 
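Every *_flattened test in this diff follows the same shape: patch the transport's cached gRPC callable, invoke the client with keyword arguments instead of a request object, and assert that the kwargs were folded into a single positional request. A minimal, self-contained sketch of that pattern (hypothetical Fake* stand-ins, not the generated GAPIC classes):

    from unittest import mock

    class FakeRequest:
        # Hypothetical stand-in for a generated request message.
        def __init__(self, name=""):
            self.name = name

    class FakeTransport:
        # Hypothetical stand-in for the transport; the test replaces this method.
        def get_model(self, request):
            raise NotImplementedError

    class FakeClient:
        # Hypothetical client: flattened kwargs are folded into one request.
        def __init__(self, transport):
            self.transport = transport

        def get_model(self, request=None, *, name=None):
            if request is not None and name is not None:
                # This is the ValueError the *_flattened_error tests expect.
                raise ValueError("request and flattened fields are mutually exclusive")
            return self.transport.get_model(request or FakeRequest(name=name))

    client = FakeClient(FakeTransport())
    with mock.patch.object(client.transport, "get_model") as call:
        client.get_model(name="name_value")
        _, args, _ = call.mock_calls[0]
        assert args[0].name == "name_value"  # kwargs landed in the request

The real tests patch type(client.transport.get_model) and its __call__ because the transport exposes gRPC callables rather than plain methods, but the assertion on args[0] is the same.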

-def test_list_models(transport: str = 'grpc', request_type=model_service.ListModelsRequest):
+def test_list_models(
+    transport: str = "grpc", request_type=model_service.ListModelsRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -873,13 +918,10 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelsResponse(
-            next_page_token='next_page_token_value',
-
+            next_page_token="next_page_token_value",
         )

         response = client.list_models(request)
@@ -894,7 +936,7 @@ def test_list_models(transport: str = 'grpc', request_type=model_service.ListMod
     assert isinstance(response, pagers.ListModelsPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_list_models_from_dict():
@@ -902,10 +944,11 @@
 @pytest.mark.asyncio
-async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest):
+async def test_list_models_async(
+    transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest
+):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -913,13 +956,11 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse(
-            next_page_token='next_page_token_value',
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelsResponse(next_page_token="next_page_token_value",)
+        )

         response = await client.list_models(request)
@@ -932,7 +973,7 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m
     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListModelsAsyncPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 @pytest.mark.asyncio
@@ -941,19 +982,15 @@ async def test_list_models_async_from_dict():
 def test_list_models_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ListModelsRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         call.return_value = model_service.ListModelsResponse()

         client.list_models(request)
@@ -965,28 +1002,23 @@ def test_list_models_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_list_models_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ListModelsRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse())
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelsResponse()
+        )

         await client.list_models(request)
@@ -997,138 +1029,98 @@ async def test_list_models_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_list_models_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelsResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_models(
-            parent='parent_value',
-        )
+        client.list_models(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 def test_list_models_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_models(
-            model_service.ListModelsRequest(),
-            parent='parent_value',
+            model_service.ListModelsRequest(), parent="parent_value",
         )


 @pytest.mark.asyncio
 async def test_list_models_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_service.ListModelsResponse()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelsResponse()
+        )
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.list_models(
-            parent='parent_value',
-        )
+        response = await client.list_models(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 @pytest.mark.asyncio
 async def test_list_models_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.list_models(
-            model_service.ListModelsRequest(),
-            parent='parent_value',
+            model_service.ListModelsRequest(), parent="parent_value",
         )


 def test_list_models_pager():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
             ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
+                models=[model.Model(),], next_page_token="ghi",
             ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
             RuntimeError,
         )

         metadata = ()
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ('parent', ''),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
         )
         pager = client.list_models(request={})
@@ -1136,147 +1128,96 @@ def test_list_models_pager():
         results = [i for i in pager]
         assert len(results) == 6
-        assert all(isinstance(i, model.Model)
-                   for i in results)
+        assert all(isinstance(i, model.Model) for i in results)
+

 def test_list_models_pages():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
             ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
+                models=[model.Model(),], next_page_token="ghi",
             ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
             RuntimeError,
         )
         pages = list(client.list_models(request={}).pages)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
             assert page_.raw_page.next_page_token == token

+
 @pytest.mark.asyncio
 async def test_list_models_async_pager():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock
+    ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
             ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
+                models=[model.Model(),], next_page_token="ghi",
             ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
             RuntimeError,
         )
         async_pager = await client.list_models(request={},)
-        assert async_pager.next_page_token == 'abc'
+        assert async_pager.next_page_token == "abc"
         responses = []
         async for response in async_pager:
             responses.append(response)

         assert len(responses) == 6
-        assert all(isinstance(i, model.Model)
-                   for i in responses)
+        assert all(isinstance(i, model.Model) for i in responses)

+
 @pytest.mark.asyncio
 async def test_list_models_async_pages():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials,
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__', new_callable=mock.AsyncMock) as call:
+        type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock
+    ) as call:
        # Set the response to a series of pages.
         call.side_effect = (
             model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
             ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
             model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
+                models=[model.Model(),], next_page_token="ghi",
             ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
             RuntimeError,
         )
         pages = []
         async for page_ in (await client.list_models(request={})).pages:
             pages.append(page_)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
             assert page_.raw_page.next_page_token == token

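The pager tests above drive list_models through four fake pages (three models, zero, one, two) terminated by a RuntimeError sentinel, and expect six items in total. A rough sketch of the paging loop they exercise, again with hypothetical Fake* names (the real pagers.ListModelsPager re-issues the RPC while next_page_token is non-empty):

    class FakePage:
        def __init__(self, models, next_page_token=""):
            self.models = models
            self.next_page_token = next_page_token

    class FakePager:
        # Hypothetical pager: iterating yields items across all pages.
        def __init__(self, method):
            self._method = method  # callable returning the next response

        def __iter__(self):
            while True:
                page = self._method()
                yield from page.models
                if not page.next_page_token:  # empty token ends pagination
                    return

    pages = iter([
        FakePage(["m1", "m2", "m3"], "abc"),
        FakePage([], "def"),
        FakePage(["m4"], "ghi"),
        FakePage(["m5", "m6"]),  # no token: final page
    ])
    pager = FakePager(lambda: next(pages))
    assert len(list(pager)) == 6  # mirrors the six model.Model() instances

The RuntimeError at the end of each side_effect tuple is never reached when pagination terminates correctly; it would only fire if the pager requested a page past the empty token.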

-def test_update_model(transport: str = 'grpc', request_type=model_service.UpdateModelRequest):
+def test_update_model(
+    transport: str = "grpc", request_type=model_service.UpdateModelRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1284,31 +1225,21 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.update_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_model.Model(
-            name='name_value',
-
-            display_name='display_name_value',
-
-            description='description_value',
-
-            metadata_schema_uri='metadata_schema_uri_value',
-
-            training_pipeline='training_pipeline_value',
-
-            artifact_uri='artifact_uri_value',
-
-            supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES],
-
-            supported_input_storage_formats=['supported_input_storage_formats_value'],
-
-            supported_output_storage_formats=['supported_output_storage_formats_value'],
-
-            etag='etag_value',
-
+            name="name_value",
+            display_name="display_name_value",
+            description="description_value",
+            metadata_schema_uri="metadata_schema_uri_value",
+            training_pipeline="training_pipeline_value",
+            artifact_uri="artifact_uri_value",
+            supported_deployment_resources_types=[
+                gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+            ],
+            supported_input_storage_formats=["supported_input_storage_formats_value"],
+            supported_output_storage_formats=["supported_output_storage_formats_value"],
+            etag="etag_value",
         )

         response = client.update_model(request)
@@ -1323,25 +1254,31 @@ def test_update_model(transport: str = 'grpc', request_type=model_service.Update
     assert isinstance(response, gca_model.Model)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.description == 'description_value'
+    assert response.description == "description_value"

-    assert response.metadata_schema_uri == 'metadata_schema_uri_value'
+    assert response.metadata_schema_uri == "metadata_schema_uri_value"

-    assert response.training_pipeline == 'training_pipeline_value'
+    assert response.training_pipeline == "training_pipeline_value"

-    assert response.artifact_uri == 'artifact_uri_value'
+    assert response.artifact_uri == "artifact_uri_value"

-    assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES]
+    assert response.supported_deployment_resources_types == [
+        gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+    ]

-    assert response.supported_input_storage_formats == ['supported_input_storage_formats_value']
+    assert response.supported_input_storage_formats == [
+        "supported_input_storage_formats_value"
+    ]

-    assert response.supported_output_storage_formats == ['supported_output_storage_formats_value']
+    assert response.supported_output_storage_formats == [
+        "supported_output_storage_formats_value"
+    ]

-    assert response.etag == 'etag_value'
+    assert response.etag == "etag_value"


 def test_update_model_from_dict():
@@ -1349,10 +1286,11 @@
 @pytest.mark.asyncio
-async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateModelRequest):
+async def test_update_model_async(
+    transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest
+):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1360,22 +1298,28 @@ async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.update_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model(
-            name='name_value',
-            display_name='display_name_value',
-            description='description_value',
-            metadata_schema_uri='metadata_schema_uri_value',
-            training_pipeline='training_pipeline_value',
-            artifact_uri='artifact_uri_value',
-            supported_deployment_resources_types=[gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES],
-            supported_input_storage_formats=['supported_input_storage_formats_value'],
-            supported_output_storage_formats=['supported_output_storage_formats_value'],
-            etag='etag_value',
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_model.Model(
+                name="name_value",
+                display_name="display_name_value",
+                description="description_value",
+                metadata_schema_uri="metadata_schema_uri_value",
+                training_pipeline="training_pipeline_value",
+                artifact_uri="artifact_uri_value",
+                supported_deployment_resources_types=[
+                    gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+                ],
+                supported_input_storage_formats=[
+                    "supported_input_storage_formats_value"
+                ],
+                supported_output_storage_formats=[
+                    "supported_output_storage_formats_value"
+                ],
+                etag="etag_value",
+            )
+        )

         response = await client.update_model(request)
@@ -1388,25 +1332,31 @@ async def test_update_model_async(transport: str = 'grpc_asyncio', request_type=
     # Establish that the response is the type that we expect.
     assert isinstance(response, gca_model.Model)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.description == 'description_value'
+    assert response.description == "description_value"

-    assert response.metadata_schema_uri == 'metadata_schema_uri_value'
+    assert response.metadata_schema_uri == "metadata_schema_uri_value"

-    assert response.training_pipeline == 'training_pipeline_value'
+    assert response.training_pipeline == "training_pipeline_value"

-    assert response.artifact_uri == 'artifact_uri_value'
+    assert response.artifact_uri == "artifact_uri_value"

-    assert response.supported_deployment_resources_types == [gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES]
+    assert response.supported_deployment_resources_types == [
+        gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+    ]

-    assert response.supported_input_storage_formats == ['supported_input_storage_formats_value']
+    assert response.supported_input_storage_formats == [
+        "supported_input_storage_formats_value"
+    ]

-    assert response.supported_output_storage_formats == ['supported_output_storage_formats_value']
+    assert response.supported_output_storage_formats == [
+        "supported_output_storage_formats_value"
+    ]

-    assert response.etag == 'etag_value'
+    assert response.etag == "etag_value"


 @pytest.mark.asyncio
@@ -1415,19 +1365,15 @@ async def test_update_model_async_from_dict():
 def test_update_model_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.UpdateModelRequest()
-    request.model.name = 'model.name/value'
+    request.model.name = "model.name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.update_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
         call.return_value = gca_model.Model()

         client.update_model(request)
@@ -1439,27 +1385,20 @@ def test_update_model_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'model.name=model.name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_update_model_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.UpdateModelRequest()
-    request.model.name = 'model.name/value'
+    request.model.name = "model.name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.update_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())

         await client.update_model(request)
@@ -1471,29 +1410,22 @@ async def test_update_model_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'model.name=model.name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"]


 def test_update_model_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.update_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_model.Model()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.update_model(
-            model=gca_model.Model(name='name_value'),
-            update_mask=field_mask.FieldMask(paths=['paths_value']),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
         )

         # Establish that the underlying call was made with the expected
@@ -1501,36 +1433,30 @@ def test_update_model_flattened():
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].model == gca_model.Model(name='name_value')
+        assert args[0].model == gca_model.Model(name="name_value")

-        assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value'])
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])


 def test_update_model_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.update_model(
             model_service.UpdateModelRequest(),
-            model=gca_model.Model(name='name_value'),
-            update_mask=field_mask.FieldMask(paths=['paths_value']),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
         )


 @pytest.mark.asyncio
 async def test_update_model_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.update_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_model.Model()

@@ -1538,8 +1464,8 @@ async def test_update_model_flattened_async():
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.update_model(
-            model=gca_model.Model(name='name_value'),
-            update_mask=field_mask.FieldMask(paths=['paths_value']),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
         )

         # Establish that the underlying call was made with the expected
@@ -1547,31 +1473,30 @@ async def test_update_model_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].model == gca_model.Model(name='name_value')
+        assert args[0].model == gca_model.Model(name="name_value")

-        assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value'])
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])


 @pytest.mark.asyncio
 async def test_update_model_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.update_model(
             model_service.UpdateModelRequest(),
-            model=gca_model.Model(name='name_value'),
-            update_mask=field_mask.FieldMask(paths=['paths_value']),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
         )


-def test_delete_model(transport: str = 'grpc', request_type=model_service.DeleteModelRequest):
+def test_delete_model(
+    transport: str = "grpc", request_type=model_service.DeleteModelRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1579,11 +1504,9 @@ def test_delete_model(transport: str = 'grpc', request_type=model_service.Delete
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.delete_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/spam')
+        call.return_value = operations_pb2.Operation(name="operations/spam")

         response = client.delete_model(request)
@@ -1602,10 +1525,11 @@ def test_delete_model_from_dict():

 @pytest.mark.asyncio
-async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteModelRequest):
+async def test_delete_model_async(
+    transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest
+):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1613,12 +1537,10 @@ async def test_delete_model_async(transport: str = 'grpc_asyncio', request_type=
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.delete_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )

         response = await client.delete_model(request)
@@ -1639,20 +1561,16 @@
 def test_delete_model_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.DeleteModelRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.delete_model),
-            '__call__') as call:
-        call.return_value = operations_pb2.Operation(name='operations/op')
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")

         client.delete_model(request)
@@ -1663,28 +1581,23 @@ def test_delete_model_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_delete_model_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.DeleteModelRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.delete_model),
-            '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )

         await client.delete_model(request)
@@ -1695,101 +1608,81 @@ async def test_delete_model_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_delete_model_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.delete_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.delete_model(
-            name='name_value',
-        )
+        client.delete_model(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_delete_model_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.delete_model(
-            model_service.DeleteModelRequest(),
-            name='name_value',
+            model_service.DeleteModelRequest(), name="name_value",
         )


 @pytest.mark.asyncio
 async def test_delete_model_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.delete_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.delete_model(
-            name='name_value',
-        )
+        response = await client.delete_model(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_delete_model_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.delete_model(
-            model_service.DeleteModelRequest(),
-            name='name_value',
+            model_service.DeleteModelRequest(), name="name_value",
         )


-def test_export_model(transport: str = 'grpc', request_type=model_service.ExportModelRequest):
+def test_export_model(
+    transport: str = "grpc", request_type=model_service.ExportModelRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1797,11 +1690,9 @@ def test_export_model(transport: str = 'grpc', request_type=model_service.Export
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.export_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/spam')
+        call.return_value = operations_pb2.Operation(name="operations/spam")

         response = client.export_model(request)
@@ -1820,10 +1711,11 @@ def test_export_model_from_dict():

 @pytest.mark.asyncio
-async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=model_service.ExportModelRequest):
+async def test_export_model_async(
+    transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest
+):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -1831,12 +1723,10 @@ async def test_export_model_async(transport: str = 'grpc_asyncio', request_type=
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.export_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )

         response = await client.export_model(request)
@@ -1857,20 +1747,16 @@
 def test_export_model_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ExportModelRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.export_model),
-            '__call__') as call:
-        call.return_value = operations_pb2.Operation(name='operations/op')
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")

         client.export_model(request)
@@ -1881,28 +1767,23 @@ def test_export_model_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_export_model_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.ExportModelRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.export_model),
-            '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op'))
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )

         await client.export_model(request)
@@ -1913,29 +1794,24 @@ async def test_export_model_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
-
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]

-def test_export_model_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+
+def test_export_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.export_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.export_model(
-            name='name_value',
-            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
         )

         # Establish that the underlying call was made with the expected
@@ -1943,47 +1819,47 @@ def test_export_model_flattened():
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"

-        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value')
+        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(
+            export_format_id="export_format_id_value"
+        )


 def test_export_model_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.export_model(
             model_service.ExportModelRequest(),
-            name='name_value',
-            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
         )


 @pytest.mark.asyncio
 async def test_export_model_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.export_model),
-            '__call__') as call:
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
         # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name='operations/op')
+        call.return_value = operations_pb2.Operation(name="operations/op")

         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name='operations/spam')
+            operations_pb2.Operation(name="operations/spam")
         )
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.export_model(
-            name='name_value',
-            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
         )

         # Establish that the underlying call was made with the expected
@@ -1991,31 +1867,34 @@ async def test_export_model_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"

-        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value')
+        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(
+            export_format_id="export_format_id_value"
+        )


 @pytest.mark.asyncio
 async def test_export_model_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.export_model(
             model_service.ExportModelRequest(),
-            name='name_value',
-            output_config=model_service.ExportModelRequest.OutputConfig(export_format_id='export_format_id_value'),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
         )


-def test_get_model_evaluation(transport: str = 'grpc', request_type=model_service.GetModelEvaluationRequest):
+def test_get_model_evaluation(
+    transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2024,16 +1903,13 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.get_model_evaluation),
-            '__call__') as call:
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation.ModelEvaluation(
-            name='name_value',
-
-            metrics_schema_uri='metrics_schema_uri_value',
-
-            slice_dimensions=['slice_dimensions_value'],
-
+            name="name_value",
+            metrics_schema_uri="metrics_schema_uri_value",
+            slice_dimensions=["slice_dimensions_value"],
         )

         response = client.get_model_evaluation(request)
@@ -2048,11 +1924,11 @@ def test_get_model_evaluation(transport: str = 'grpc', request_type=model_servic
     assert isinstance(response, model_evaluation.ModelEvaluation)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.metrics_schema_uri == 'metrics_schema_uri_value'
+    assert response.metrics_schema_uri == "metrics_schema_uri_value"

-    assert response.slice_dimensions == ['slice_dimensions_value']
+    assert response.slice_dimensions == ["slice_dimensions_value"]


 def test_get_model_evaluation_from_dict():
@@ -2060,10 +1936,12 @@
 @pytest.mark.asyncio
-async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationRequest):
+async def test_get_model_evaluation_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.GetModelEvaluationRequest,
+):
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2072,14 +1950,16 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', reque
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.get_model_evaluation),
-            '__call__') as call:
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation(
-            name='name_value',
-            metrics_schema_uri='metrics_schema_uri_value',
-            slice_dimensions=['slice_dimensions_value'],
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation(
+                name="name_value",
+                metrics_schema_uri="metrics_schema_uri_value",
+                slice_dimensions=["slice_dimensions_value"],
+            )
+        )

         response = await client.get_model_evaluation(request)
@@ -2092,11 +1972,11 @@ async def test_get_model_evaluation_async(transport: str = 'grpc_asyncio', reque
     # Establish that the response is the type that we expect.
     assert isinstance(response, model_evaluation.ModelEvaluation)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.metrics_schema_uri == 'metrics_schema_uri_value'
+    assert response.metrics_schema_uri == "metrics_schema_uri_value"

-    assert response.slice_dimensions == ['slice_dimensions_value']
+    assert response.slice_dimensions == ["slice_dimensions_value"]


 @pytest.mark.asyncio
@@ -2105,19 +1985,17 @@ async def test_get_model_evaluation_async_from_dict():
 def test_get_model_evaluation_field_headers():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.GetModelEvaluationRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.get_model_evaluation),
-            '__call__') as call:
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
         call.return_value = model_evaluation.ModelEvaluation()

         client.get_model_evaluation(request)
@@ -2129,28 +2007,25 @@ def test_get_model_evaluation_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_get_model_evaluation_field_headers_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = model_service.GetModelEvaluationRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.get_model_evaluation),
-            '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation())
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation()
+        )

         await client.get_model_evaluation(request)
@@ -2161,99 +2036,85 @@ async def test_get_model_evaluation_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_get_model_evaluation_flattened():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.get_model_evaluation),
-            '__call__') as call:
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation.ModelEvaluation()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.get_model_evaluation(
-            name='name_value',
-        )
+        client.get_model_evaluation(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_get_model_evaluation_flattened_error():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.get_model_evaluation(
-            model_service.GetModelEvaluationRequest(),
-            name='name_value',
+            model_service.GetModelEvaluationRequest(), name="name_value",
         )


 @pytest.mark.asyncio
 async def test_get_model_evaluation_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.get_model_evaluation),
-            '__call__') as call:
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = model_evaluation.ModelEvaluation()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation.ModelEvaluation())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation()
+        )
         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.get_model_evaluation(
-            name='name_value',
-        )
+        response = await client.get_model_evaluation(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_get_model_evaluation_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.get_model_evaluation(
-            model_service.GetModelEvaluationRequest(),
-            name='name_value',
+            model_service.GetModelEvaluationRequest(), name="name_value",
        )


-def test_list_model_evaluations(transport: str = 'grpc', request_type=model_service.ListModelEvaluationsRequest):
+def test_list_model_evaluations(
+    transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest
+):
     client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -2262,12 +2123,11 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-            type(client.transport.list_model_evaluations),
-            '__call__') as call:
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluations(request) @@ -2282,7 +2142,7 @@ def test_list_model_evaluations(transport: str = 'grpc', request_type=model_serv assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluations_from_dict(): @@ -2290,10 +2150,12 @@ def test_list_model_evaluations_from_dict(): @pytest.mark.asyncio -async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationsRequest): +async def test_list_model_evaluations_async( + transport: str = "grpc_asyncio", + request_type=model_service.ListModelEvaluationsRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2302,12 +2164,14 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluations(request) @@ -2320,7 +2184,7 @@ async def test_list_model_evaluations_async(transport: str = 'grpc_asyncio', req # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2329,19 +2193,17 @@ async def test_list_model_evaluations_async_from_dict(): def test_list_model_evaluations_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationsResponse() client.list_model_evaluations(request) @@ -2353,28 +2215,25 @@ def test_list_model_evaluations_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + type(client.transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) await client.list_model_evaluations(request) @@ -2385,104 +2244,87 @@ async def test_list_model_evaluations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluations_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluations( - parent='parent_value', - ) + client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluations( - parent='parent_value', - ) + response = await client.list_model_evaluations(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), - parent='parent_value', + model_service.ListModelEvaluationsRequest(), parent="parent_value", ) def test_list_model_evaluations_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. 
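        # Four pages are simulated below (3 + 0 + 1 + 2 evaluations); the
        # final page carries an empty next_page_token, which ends the
        # pagination, so the pager is expected to surface 6 results in all.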
call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2491,17 +2333,14 @@ def test_list_model_evaluations_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2514,9 +2353,7 @@ def test_list_model_evaluations_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluations(request={}) @@ -2524,18 +2361,16 @@ def test_list_model_evaluations_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in results) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) + def test_list_model_evaluations_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__') as call: + type(client.transport.list_model_evaluations), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2544,17 +2379,14 @@ def test_list_model_evaluations_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2565,19 +2397,20 @@ def test_list_model_evaluations_pages(): RuntimeError, ) pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2586,17 +2419,14 @@ async def test_list_model_evaluations_async_pager(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2607,25 +2437,25 @@ async def test_list_model_evaluations_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) - for i in responses) + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) + @pytest.mark.asyncio async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluations), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluations), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationsResponse( @@ -2634,17 +2464,14 @@ async def test_list_model_evaluations_async_pages(): model_evaluation.ModelEvaluation(), model_evaluation.ModelEvaluation(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[], - next_page_token='def', + model_evaluations=[], next_page_token="def", ), model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - ], - next_page_token='ghi', + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", ), model_service.ListModelEvaluationsResponse( model_evaluations=[ @@ -2657,14 +2484,15 @@ async def test_list_model_evaluations_async_pages(): pages = [] async for page_ in (await client.list_model_evaluations(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_service.GetModelEvaluationSliceRequest): +def test_get_model_evaluation_slice( + transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2673,14 +2501,11 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - - metrics_schema_uri='metrics_schema_uri_value', - + name="name_value", metrics_schema_uri="metrics_schema_uri_value", ) response = client.get_model_evaluation_slice(request) @@ -2695,9 +2520,9 @@ def test_get_model_evaluation_slice(transport: str = 'grpc', request_type=model_ assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" def test_get_model_evaluation_slice_from_dict(): @@ -2705,10 +2530,12 @@ def test_get_model_evaluation_slice_from_dict(): @pytest.mark.asyncio -async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelEvaluationSliceRequest): +async def test_get_model_evaluation_slice_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationSliceRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2717,13 +2544,14 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice( - name='name_value', - metrics_schema_uri='metrics_schema_uri_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + ) response = await client.get_model_evaluation_slice(request) @@ -2736,9 +2564,9 @@ async def test_get_model_evaluation_slice_async(transport: str = 'grpc_asyncio', # Establish that the response is the type that we expect. assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.metrics_schema_uri == 'metrics_schema_uri_value' + assert response.metrics_schema_uri == "metrics_schema_uri_value" @pytest.mark.asyncio @@ -2747,19 +2575,17 @@ async def test_get_model_evaluation_slice_async_from_dict(): def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: call.return_value = model_evaluation_slice.ModelEvaluationSlice() client.get_model_evaluation_slice(request) @@ -2771,28 +2597,25 @@ def test_get_model_evaluation_slice_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.GetModelEvaluationSliceRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) await client.get_model_evaluation_slice(request) @@ -2803,99 +2626,85 @@ async def test_get_model_evaluation_slice_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_model_evaluation_slice( - name='name_value', - ) + client.get_model_evaluation_slice(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), - name='name_value', + model_service.GetModelEvaluationSliceRequest(), name="name_value", ) @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_model_evaluation_slice), - '__call__') as call: + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_evaluation_slice.ModelEvaluationSlice() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_evaluation_slice.ModelEvaluationSlice()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_model_evaluation_slice( - name='name_value', - ) + response = await client.get_model_evaluation_slice(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_get_model_evaluation_slice_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), - name='name_value', + model_service.GetModelEvaluationSliceRequest(), name="name_value", ) -def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=model_service.ListModelEvaluationSlicesRequest): +def test_list_model_evaluation_slices( + transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest +): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2904,12 +2713,11 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_model_evaluation_slices(request) @@ -2924,7 +2732,7 @@ def test_list_model_evaluation_slices(transport: str = 'grpc', request_type=mode assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_model_evaluation_slices_from_dict(): @@ -2932,10 +2740,12 @@ def test_list_model_evaluation_slices_from_dict(): @pytest.mark.asyncio -async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelEvaluationSlicesRequest): +async def test_list_model_evaluation_slices_async( + transport: str = "grpc_asyncio", + request_type=model_service.ListModelEvaluationSlicesRequest, +): client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -2944,12 +2754,14 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_model_evaluation_slices(request) @@ -2962,7 +2774,7 @@ async def test_list_model_evaluation_slices_async(transport: str = 'grpc_asyncio # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -2971,19 +2783,17 @@ async def test_list_model_evaluation_slices_async_from_dict(): def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: call.return_value = model_service.ListModelEvaluationSlicesResponse() client.list_model_evaluation_slices(request) @@ -2995,28 +2805,25 @@ def test_list_model_evaluation_slices_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio async def test_list_model_evaluation_slices_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = model_service.ListModelEvaluationSlicesRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) await client.list_model_evaluation_slices(request) @@ -3027,104 +2834,87 @@ async def test_list_model_evaluation_slices_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_model_evaluation_slices( - parent='parent_value', - ) + client.list_model_evaluation_slices(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), - parent='parent_value', + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", ) @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelEvaluationSlicesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelEvaluationSlicesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_model_evaluation_slices( - parent='parent_value', - ) + response = await client.list_model_evaluation_slices(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_model_evaluation_slices_flattened_error_async(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), - parent='parent_value', + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", ) def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3133,17 +2923,16 @@ def test_list_model_evaluation_slices_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3156,9 +2945,7 @@ def test_list_model_evaluation_slices_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_model_evaluation_slices(request={}) @@ -3166,18 +2953,18 @@ def test_list_model_evaluation_slices_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in results) + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results + ) + def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__') as call: + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3186,17 +2973,16 @@ def test_list_model_evaluation_slices_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3207,19 +2993,20 @@ def test_list_model_evaluation_slices_pages(): RuntimeError, ) pages = list(client.list_model_evaluation_slices(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pager(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluation_slices), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3228,17 +3015,16 @@ async def test_list_model_evaluation_slices_async_pager(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3249,25 +3035,28 @@ async def test_list_model_evaluation_slices_async_pager(): RuntimeError, ) async_pager = await client.list_model_evaluation_slices(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in responses) + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) + for i in responses + ) + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async_pages(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_model_evaluation_slices), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_model_evaluation_slices), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelEvaluationSlicesResponse( @@ -3276,17 +3065,16 @@ async def test_list_model_evaluation_slices_async_pages(): model_evaluation_slice.ModelEvaluationSlice(), model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], - next_page_token='def', + model_evaluation_slices=[], next_page_token="def", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ model_evaluation_slice.ModelEvaluationSlice(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelEvaluationSlicesResponse( model_evaluation_slices=[ @@ -3297,9 +3085,11 @@ async def test_list_model_evaluation_slices_async_pages(): RuntimeError, ) pages = [] - async for page_ in (await client.list_model_evaluation_slices(request={})).pages: + async for page_ in ( + await client.list_model_evaluation_slices(request={}) + ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -3310,8 +3100,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. 
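# Illustrative sketch (assumed usage, not part of the generated tests): the
# rule enforced by test_credentials_transport_error is that a ready-made
# transport already carries its own credentials, so it must be passed alone.
# The package-level client export is an assumption here.
from google.auth import credentials as ga_credentials
from google.cloud import aiplatform_v1beta1
from google.cloud.aiplatform_v1beta1.services.model_service import transports

transport = transports.ModelServiceGrpcTransport(
    credentials=ga_credentials.AnonymousCredentials(),
)
client = aiplatform_v1beta1.ModelServiceClient(transport=transport)  # OK
# Combining transport= with credentials=, a credentials file, or scopes
# raises ValueError, as the surrounding assertions demonstrate.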
@@ -3330,8 +3119,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -3359,13 +3147,13 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -3373,13 +3161,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.ModelServiceGrpcTransport, - ) + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) def test_model_service_base_transport_error(): @@ -3387,13 +3170,15 @@ def test_model_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_model_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.ModelServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -3402,17 +3187,17 @@ def test_model_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
methods = ( - 'upload_model', - 'get_model', - 'list_models', - 'update_model', - 'delete_model', - 'export_model', - 'get_model_evaluation', - 'list_model_evaluations', - 'get_model_evaluation_slice', - 'list_model_evaluation_slices', - ) + "upload_model", + "get_model", + "list_models", + "update_model", + "delete_model", + "export_model", + "get_model_evaluation", + "list_model_evaluations", + "get_model_evaluation_slice", + "list_model_evaluation_slices", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -3425,23 +3210,28 @@ def test_model_service_base_transport(): def test_model_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_model_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport() @@ -3450,11 +3240,11 @@ def test_model_service_base_transport_with_adc(): def test_model_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) ModelServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -3462,37 +3252,43 @@ def test_model_service_auth_adc(): def test_model_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. 
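    # "ADC" here means Application Default Credentials, which the transport
    # resolves through google.auth.default(); the mock patch below substitutes
    # anonymous credentials so the test needs no real environment.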
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.ModelServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_model_service_host_no_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_model_service_host_with_port(): client = ModelServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_model_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -3500,24 +3296,28 @@ def test_model_service_grpc_transport_channel(): def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3526,7 +3326,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3542,9 +3342,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3552,17 +3350,20 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -3579,9 +3380,7 @@ def test_model_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -3590,16 +3389,12 @@ def test_model_service_transport_channel_mtls_with_adc( def test_model_service_grpc_lro_client(): client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - 
-        transport='grpc',
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
     )
     transport = client.transport

     # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client
@@ -3607,36 +3402,34 @@ def test_model_service_grpc_lro_async_client():
     client = ModelServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport='grpc_asyncio',
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
     )
     transport = client.transport

     # Ensure that we have a api-core operations client.
-    assert isinstance(
-        transport.operations_client,
-        operations_v1.OperationsAsyncClient,
-    )
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client

+
 def test_endpoint_path():
     project = "squid"
     location = "clam"
     endpoint = "whelk"

-    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, )
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
     actual = ModelServiceClient.endpoint_path(project, location, endpoint)
     assert expected == actual


 def test_parse_endpoint_path():
     expected = {
-    "project": "octopus",
-    "location": "oyster",
-    "endpoint": "nudibranch",
-
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
     }
     path = ModelServiceClient.endpoint_path(**expected)
@@ -3644,22 +3437,24 @@ def test_parse_endpoint_path():
     actual = ModelServiceClient.parse_endpoint_path(path)
     assert expected == actual

+
 def test_model_path():
     project = "cuttlefish"
     location = "mussel"
     model = "winkle"

-    expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, )
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
     actual = ModelServiceClient.model_path(project, location, model)
     assert expected == actual


 def test_parse_model_path():
     expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "model": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
     }
     path = ModelServiceClient.model_path(**expected)
@@ -3667,24 +3462,28 @@ def test_parse_model_path():
     actual = ModelServiceClient.parse_model_path(path)
     assert expected == actual

+
 def test_model_evaluation_path():
     project = "squid"
     location = "clam"
     model = "whelk"
     evaluation = "octopus"

-    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(project=project, location=location, model=model, evaluation=evaluation, )
-    actual = ModelServiceClient.model_evaluation_path(project, location, model, evaluation)
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(
+        project=project, location=location, model=model, evaluation=evaluation,
+    )
+    actual = ModelServiceClient.model_evaluation_path(
+        project, location, model, evaluation
+    )
     assert expected == actual

 def test_parse_model_evaluation_path():
     expected = {
-    "project": "oyster",
-    "location": "nudibranch",
-    "model": "cuttlefish",
-    "evaluation": "mussel",
-
+        "project": "oyster",
+        "location": "nudibranch",
+        "model": "cuttlefish",
+        "evaluation": "mussel",
     }
     path = ModelServiceClient.model_evaluation_path(**expected)
@@ -3692,6 +3491,7 @@ def test_parse_model_evaluation_path():
     actual = ModelServiceClient.parse_model_evaluation_path(path)
     assert expected == actual

+
 def test_model_evaluation_slice_path():
     project = "winkle"
     location = "nautilus"
@@ -3699,19 +3499,26 @@ def test_model_evaluation_slice_path():
     evaluation = "abalone"
     slice = "squid"

-    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(project=project, location=location, model=model, evaluation=evaluation, slice=slice, )
-    actual = ModelServiceClient.model_evaluation_slice_path(project, location, model, evaluation, slice)
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(
+        project=project,
+        location=location,
+        model=model,
+        evaluation=evaluation,
+        slice=slice,
+    )
+    actual = ModelServiceClient.model_evaluation_slice_path(
+        project, location, model, evaluation, slice
+    )
     assert expected == actual


 def test_parse_model_evaluation_slice_path():
     expected = {
-    "project": "clam",
-    "location": "whelk",
-    "model": "octopus",
-    "evaluation": "oyster",
-    "slice": "nudibranch",
-
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
+        "evaluation": "oyster",
+        "slice": "nudibranch",
     }
     path = ModelServiceClient.model_evaluation_slice_path(**expected)
@@ -3719,22 +3526,26 @@ def test_parse_model_evaluation_slice_path():
     actual = ModelServiceClient.parse_model_evaluation_slice_path(path)
     assert expected == actual

+
 def test_training_pipeline_path():
     project = "cuttlefish"
     location = "mussel"
     training_pipeline = "winkle"

-    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, )
-    actual = ModelServiceClient.training_pipeline_path(project, location, training_pipeline)
+    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
+        project=project, location=location, training_pipeline=training_pipeline,
+    )
+    actual = ModelServiceClient.training_pipeline_path(
+        project, location, training_pipeline
+    )
     assert expected == actual


 def test_parse_training_pipeline_path():
     expected = {
-    "project": "nautilus",
-    "location": "scallop",
-    "training_pipeline": "abalone",
-
+        "project": "nautilus",
+        "location": "scallop",
+        "training_pipeline": "abalone",
     }
     path = ModelServiceClient.training_pipeline_path(**expected)
@@ -3742,18 +3553,20 @@ def test_parse_training_pipeline_path():
     actual = ModelServiceClient.parse_training_pipeline_path(path)
     assert expected == actual

+
 def test_common_billing_account_path():
     billing_account = "squid"

-    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
     actual = ModelServiceClient.common_billing_account_path(billing_account)
     assert expected == actual


 def test_parse_common_billing_account_path():
     expected = {
-    "billing_account": "clam",
-
+        "billing_account": "clam",
     }
     path = ModelServiceClient.common_billing_account_path(**expected)
@@ -3761,18 +3574,18 @@ def test_parse_common_billing_account_path():
     actual = ModelServiceClient.parse_common_billing_account_path(path)
     assert expected == actual

+
 def test_common_folder_path():
     folder = "whelk"

-    expected = "folders/{folder}".format(folder=folder, )
+    expected = "folders/{folder}".format(folder=folder,)
    actual = ModelServiceClient.common_folder_path(folder)
     assert expected == actual


 def test_parse_common_folder_path():
     expected = {
-    "folder": "octopus",
-
+        "folder": "octopus",
     }
     path = ModelServiceClient.common_folder_path(**expected)
@@ -3780,18 +3593,18 @@ def test_parse_common_folder_path():
     actual = ModelServiceClient.parse_common_folder_path(path)
     assert expected == actual

+
 def test_common_organization_path():
     organization = "oyster"

-    expected = "organizations/{organization}".format(organization=organization, )
+    expected = "organizations/{organization}".format(organization=organization,)
     actual = ModelServiceClient.common_organization_path(organization)
     assert expected == actual


 def test_parse_common_organization_path():
     expected = {
-    "organization": "nudibranch",
-
+        "organization": "nudibranch",
     }
     path = ModelServiceClient.common_organization_path(**expected)
@@ -3799,18 +3612,18 @@ def test_parse_common_organization_path():
     actual = ModelServiceClient.parse_common_organization_path(path)
     assert expected == actual

+
 def test_common_project_path():
     project = "cuttlefish"

-    expected = "projects/{project}".format(project=project, )
+    expected = "projects/{project}".format(project=project,)
     actual = ModelServiceClient.common_project_path(project)
     assert expected == actual


 def test_parse_common_project_path():
     expected = {
-    "project": "mussel",
-
+        "project": "mussel",
     }
     path = ModelServiceClient.common_project_path(**expected)
@@ -3818,20 +3631,22 @@ def test_parse_common_project_path():
     actual = ModelServiceClient.parse_common_project_path(path)
     assert expected == actual

+
 def test_common_location_path():
     project = "winkle"
     location = "nautilus"

-    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
     actual = ModelServiceClient.common_location_path(project, location)
     assert expected == actual


 def test_parse_common_location_path():
     expected = {
-    "project": "scallop",
-    "location": "abalone",
-
+        "project": "scallop",
+        "location": "abalone",
     }
     path = ModelServiceClient.common_location_path(**expected)
@@ -3843,17 +3658,19 @@ def test_parse_common_location_path():
     actual = ModelServiceClient.parse_common_location_path(path)
     assert expected == actual


 def test_client_withDEFAULT_CLIENT_INFO():
     client_info = gapic_v1.client_info.ClientInfo()
-    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.ModelServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
         client = ModelServiceClient(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
         )
         prep.assert_called_once_with(client_info)

-    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+    with mock.patch.object(
+        transports.ModelServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
         transport_class = ModelServiceClient.get_transport_class()
         transport = transport_class(
-            credentials=credentials.AnonymousCredentials(),
-            client_info=client_info,
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
         )
         prep.assert_called_once_with(client_info)
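The path-helper hunks above are pure black reformatting; the helpers themselves still behave as the assertions expect. As a minimal usage sketch (assuming the regenerated google-cloud-aiplatform package from this series is installed; the project, location, and model values below are illustrative, not taken from the patch):

    from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

    # Both helpers are classmethods, so no credentials or network access are needed.
    path = ModelServiceClient.model_path("my-project", "us-central1", "my-model")
    assert path == "projects/my-project/locations/us-central1/models/my-model"

    # parse_model_path() inverts model_path(), returning the path segments as a dict,
    # exactly as the test_parse_model_path() case above asserts.
    assert ModelServiceClient.parse_model_path(path) == {
        "project": "my-project",
        "location": "us-central1",
        "model": "my-model",
    }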
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
index 7ea561790e..ada82b91c0 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
@@ -35,8 +35,12 @@
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceAsyncClient
-from google.cloud.aiplatform_v1beta1.services.pipeline_service import PipelineServiceClient
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
+    PipelineServiceAsyncClient,
+)
+from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
+    PipelineServiceClient,
+)
 from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers
 from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports
 from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
@@ -49,7 +53,9 @@
 from google.cloud.aiplatform_v1beta1.types import pipeline_service
 from google.cloud.aiplatform_v1beta1.types import pipeline_state
 from google.cloud.aiplatform_v1beta1.types import training_pipeline
-from google.cloud.aiplatform_v1beta1.types import training_pipeline as gca_training_pipeline
+from google.cloud.aiplatform_v1beta1.types import (
+    training_pipeline as gca_training_pipeline,
+)
 from google.longrunning import operations_pb2
 from google.oauth2 import service_account
 from google.protobuf import any_pb2 as gp_any  # type: ignore
@@ -67,7 +73,11 @@ def client_cert_source_callback():

 # This method modifies the default endpoint so the client can produce a different
 # mtls endpoint for endpoint testing purposes.
 def modify_default_endpoint(client):
-    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT
+    return (
+        "foo.googleapis.com"
+        if ("localhost" in client.DEFAULT_ENDPOINT)
+        else client.DEFAULT_ENDPOINT
+    )


 def test__get_default_mtls_endpoint():
@@ -78,17 +88,35 @@ def test__get_default_mtls_endpoint():
     non_googleapi = "api.example.com"

     assert PipelineServiceClient._get_default_mtls_endpoint(None) is None
-    assert PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
-    assert PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(api_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint)
+        == api_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
+        == sandbox_mtls_endpoint
+    )
+    assert (
+        PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+    )


-@pytest.mark.parametrize("client_class", [PipelineServiceClient, PipelineServiceAsyncClient])
+@pytest.mark.parametrize(
+    "client_class", [PipelineServiceClient, PipelineServiceAsyncClient]
+)
 def test_pipeline_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
-    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_file"
+    ) as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
         assert client.transport._credentials == creds
@@ -96,7 +124,7 @@ def test_pipeline_service_client_from_service_account_file(client_class):
         client = client_class.from_service_account_json("dummy/file/path.json")
         assert client.transport._credentials == creds

-    assert client.transport._host == 'aiplatform.googleapis.com:443'
+    assert client.transport._host == "aiplatform.googleapis.com:443"


 def test_pipeline_service_client_get_transport_class():
@@ -107,29 +135,44 @@ def test_pipeline_service_client_get_transport_class():
     assert transport == transports.PipelineServiceGrpcTransport


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient))
-@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient))
-def test_pipeline_service_client_client_options(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+@mock.patch.object(
+    PipelineServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceClient),
+)
+@mock.patch.object(
+    PipelineServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceAsyncClient),
+)
+def test_pipeline_service_client_client_options(
+    client_class, transport_class, transport_name
+):
     # Check that if channel is provided we won't create a new one.
-    with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc:
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials()
-        )
+    with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc:
+        transport = transport_class(credentials=credentials.AnonymousCredentials())
         client = client_class(transport=transport)
         gtc.assert_not_called()

     # Check that if channel is provided via str we will create a new one.
-    with mock.patch.object(PipelineServiceClient, 'get_transport_class') as gtc:
+    with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc:
         client = client_class(transport=transport_name)
         gtc.assert_called()

     # Check the case api_endpoint is provided.
     options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -145,7 +188,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "never".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -161,7 +204,7 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
     # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "always".
     with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
-        with mock.patch.object(transport_class, '__init__') as patched:
+        with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
             patched.assert_called_once_with(
@@ -181,13 +224,15 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
             client = client_class()

     # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
         with pytest.raises(ValueError):
             client = client_class()

     # Check the case quota_project_id is provided
     options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -200,26 +245,66 @@ def test_pipeline_service_client_client_options(client_class, transport_class, t
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "true"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc", "false"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio", "false")
-])
-@mock.patch.object(PipelineServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceClient))
-@mock.patch.object(PipelineServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PipelineServiceAsyncClient))
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name,use_client_cert_env",
+    [
+        (
+            PipelineServiceClient,
+            transports.PipelineServiceGrpcTransport,
+            "grpc",
+            "true",
+        ),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "true",
+        ),
+        (
+            PipelineServiceClient,
+            transports.PipelineServiceGrpcTransport,
+            "grpc",
+            "false",
+        ),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "false",
+        ),
+    ],
+)
+@mock.patch.object(
+    PipelineServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceClient),
+)
+@mock.patch.object(
+    PipelineServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(PipelineServiceAsyncClient),
+)
 @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
-def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+def test_pipeline_service_client_mtls_env_auto(
+    client_class, transport_class, transport_name, use_client_cert_env
+):
     # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
     # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.

     # Check the case client_cert_source is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
-        with mock.patch.object(transport_class, '__init__') as patched:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        options = client_options.ClientOptions(
+            client_cert_source=client_cert_source_callback
+        )
+        with mock.patch.object(transport_class, "__init__") as patched:
             ssl_channel_creds = mock.Mock()
-            with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds):
+            with mock.patch(
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
+            ):
                 patched.return_value = None
                 client = client_class(client_options=options)

@@ -242,11 +327,21 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
-                    with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
                         if use_client_cert_env == "false":
                             is_mtls_mock.return_value = False
                             ssl_credentials_mock.return_value = None
@@ -256,7 +351,9 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
                             is_mtls_mock.return_value = True
                             ssl_credentials_mock.return_value = mock.Mock()
                             expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = ssl_credentials_mock.return_value
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )

                         patched.return_value = None
                         client = client_class()
@@ -271,10 +368,17 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
     # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
-        with mock.patch.object(transport_class, '__init__') as patched:
-            with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None):
-                with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock:
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
                     is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
@@ -289,16 +393,23 @@ def test_pipeline_service_client_mtls_env_auto(client_class, transport_class, tr
             )


-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_pipeline_service_client_client_options_scopes(client_class, transport_class, transport_name):
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_pipeline_service_client_client_options_scopes(
+    client_class, transport_class, transport_name
+):
     # Check the case scopes are provided.
-    options = client_options.ClientOptions(
-        scopes=["1", "2"],
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(scopes=["1", "2"],)
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -311,16 +422,24 @@ def test_pipeline_service_client_client_options_scopes(client_class, transport_c
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-@pytest.mark.parametrize("client_class,transport_class,transport_name", [
-    (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
-    (PipelineServiceAsyncClient, transports.PipelineServiceGrpcAsyncIOTransport, "grpc_asyncio")
-])
-def test_pipeline_service_client_client_options_credentials_file(client_class, transport_class, transport_name):
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name",
+    [
+        (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"),
+        (
+            PipelineServiceAsyncClient,
+            transports.PipelineServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+        ),
+    ],
+)
+def test_pipeline_service_client_client_options_credentials_file(
+    client_class, transport_class, transport_name
+):
     # Check the case credentials file is provided.
-    options = client_options.ClientOptions(
-        credentials_file="credentials.json"
-    )
-    with mock.patch.object(transport_class, '__init__') as patched:
+    options = client_options.ClientOptions(credentials_file="credentials.json")
+    with mock.patch.object(transport_class, "__init__") as patched:
         patched.return_value = None
         client = client_class(client_options=options)
         patched.assert_called_once_with(
@@ -335,10 +454,12 @@ def test_pipeline_service_client_client_options_credentials_file(client_class, t


 def test_pipeline_service_client_client_options_from_dict():
-    with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__') as grpc_transport:
+    with mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__"
+    ) as grpc_transport:
         grpc_transport.return_value = None
         client = PipelineServiceClient(
-            client_options={'api_endpoint': 'squid.clam.whelk'}
+            client_options={"api_endpoint": "squid.clam.whelk"}
         )
         grpc_transport.assert_called_once_with(
             credentials=None,
@@ -351,10 +472,11 @@ def test_pipeline_service_client_client_options_from_dict():
         )


-def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CreateTrainingPipelineRequest):
+def test_create_training_pipeline(
+    transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest
+):
     client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -363,18 +485,14 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_training_pipeline),
-        '__call__') as call:
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_training_pipeline.TrainingPipeline(
-            name='name_value',
-
-            display_name='display_name_value',
-
-            training_task_definition='training_task_definition_value',
-
+            name="name_value",
+            display_name="display_name_value",
+            training_task_definition="training_task_definition_value",
             state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
-
         )

         response = client.create_training_pipeline(request)
@@ -389,11 +507,11 @@ def test_create_training_pipeline(transport: str = 'grpc', request_type=pipeline

     assert isinstance(response, gca_training_pipeline.TrainingPipeline)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.training_task_definition == 'training_task_definition_value'
+    assert response.training_task_definition == "training_task_definition_value"

     assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED

@@ -403,10 +521,12 @@ def test_create_training_pipeline_from_dict():


 @pytest.mark.asyncio
-async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CreateTrainingPipelineRequest):
+async def test_create_training_pipeline_async(
+    transport: str = "grpc_asyncio",
+    request_type=pipeline_service.CreateTrainingPipelineRequest,
+):
     client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -415,15 +535,17 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', r

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_training_pipeline),
-        '__call__') as call:
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline(
-            name='name_value',
-            display_name='display_name_value',
-            training_task_definition='training_task_definition_value',
-            state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_training_pipeline.TrainingPipeline(
+                name="name_value",
+                display_name="display_name_value",
+                training_task_definition="training_task_definition_value",
+                state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
+            )
+        )

         response = await client.create_training_pipeline(request)

@@ -436,11 +558,11 @@ async def test_create_training_pipeline_async(transport: str = 'grpc_asyncio', r
     # Establish that the response is the type that we expect.
     assert isinstance(response, gca_training_pipeline.TrainingPipeline)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.training_task_definition == 'training_task_definition_value'
+    assert response.training_task_definition == "training_task_definition_value"

     assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED

@@ -451,19 +573,17 @@ async def test_create_training_pipeline_async_from_dict():


 def test_create_training_pipeline_field_headers():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.CreateTrainingPipelineRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_training_pipeline),
-        '__call__') as call:
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
         call.return_value = gca_training_pipeline.TrainingPipeline()

         client.create_training_pipeline(request)
@@ -475,28 +595,25 @@ def test_create_training_pipeline_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_create_training_pipeline_field_headers_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.CreateTrainingPipelineRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_training_pipeline),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline())
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_training_pipeline.TrainingPipeline()
+        )

         await client.create_training_pipeline(request)

@@ -507,29 +624,24 @@ async def test_create_training_pipeline_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_create_training_pipeline_flattened():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_training_pipeline),
-        '__call__') as call:
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_training_pipeline.TrainingPipeline()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         client.create_training_pipeline(
-            parent='parent_value',
-            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -537,45 +649,45 @@ def test_create_training_pipeline_flattened():
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value')
+        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(
+            name="name_value"
+        )


 def test_create_training_pipeline_flattened_error():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.create_training_pipeline(
             pipeline_service.CreateTrainingPipelineRequest(),
-            parent='parent_value',
-            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
         )


 @pytest.mark.asyncio
 async def test_create_training_pipeline_flattened_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.create_training_pipeline),
-        '__call__') as call:
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = gca_training_pipeline.TrainingPipeline()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_training_pipeline.TrainingPipeline())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_training_pipeline.TrainingPipeline()
+        )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
         response = await client.create_training_pipeline(
-            parent='parent_value',
-            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
         )

         # Establish that the underlying call was made with the expected
@@ -583,31 +695,32 @@ async def test_create_training_pipeline_flattened_async():
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"

-        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(name='name_value')
+        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(
+            name="name_value"
+        )


 @pytest.mark.asyncio
 async def test_create_training_pipeline_flattened_error_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.create_training_pipeline(
             pipeline_service.CreateTrainingPipelineRequest(),
-            parent='parent_value',
-            training_pipeline=gca_training_pipeline.TrainingPipeline(name='name_value'),
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
         )


-def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.GetTrainingPipelineRequest):
+def test_get_training_pipeline(
+    transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest
+):
     client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -616,18 +729,14 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_training_pipeline),
-        '__call__') as call:
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = training_pipeline.TrainingPipeline(
-            name='name_value',
-
-            display_name='display_name_value',
-
-            training_task_definition='training_task_definition_value',
-
+            name="name_value",
+            display_name="display_name_value",
+            training_task_definition="training_task_definition_value",
             state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
-
         )

         response = client.get_training_pipeline(request)
@@ -642,11 +751,11 @@ def test_get_training_pipeline(transport: str = 'grpc', request_type=pipeline_se

     assert isinstance(response, training_pipeline.TrainingPipeline)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.training_task_definition == 'training_task_definition_value'
+    assert response.training_task_definition == "training_task_definition_value"

     assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED

@@ -656,10 +765,12 @@ def test_get_training_pipeline_from_dict():


 @pytest.mark.asyncio
-async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.GetTrainingPipelineRequest):
+async def test_get_training_pipeline_async(
+    transport: str = "grpc_asyncio",
+    request_type=pipeline_service.GetTrainingPipelineRequest,
+):
     client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -668,15 +779,17 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', requ

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_training_pipeline),
-        '__call__') as call:
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline(
-            name='name_value',
-            display_name='display_name_value',
-            training_task_definition='training_task_definition_value',
-            state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            training_pipeline.TrainingPipeline(
+                name="name_value",
+                display_name="display_name_value",
+                training_task_definition="training_task_definition_value",
+                state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED,
+            )
+        )

         response = await client.get_training_pipeline(request)

@@ -689,11 +802,11 @@ async def test_get_training_pipeline_async(transport: str = 'grpc_asyncio', requ
     # Establish that the response is the type that we expect.
     assert isinstance(response, training_pipeline.TrainingPipeline)

-    assert response.name == 'name_value'
+    assert response.name == "name_value"

-    assert response.display_name == 'display_name_value'
+    assert response.display_name == "display_name_value"

-    assert response.training_task_definition == 'training_task_definition_value'
+    assert response.training_task_definition == "training_task_definition_value"

     assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED

@@ -704,19 +817,17 @@ async def test_get_training_pipeline_async_from_dict():


 def test_get_training_pipeline_field_headers():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.GetTrainingPipelineRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_training_pipeline),
-        '__call__') as call:
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
         call.return_value = training_pipeline.TrainingPipeline()

         client.get_training_pipeline(request)
@@ -728,28 +839,25 @@ def test_get_training_pipeline_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_get_training_pipeline_field_headers_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.GetTrainingPipelineRequest()
-    request.name = 'name/value'
+    request.name = "name/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_training_pipeline),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline())
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            training_pipeline.TrainingPipeline()
+        )

         await client.get_training_pipeline(request)

@@ -760,99 +868,85 @@ async def test_get_training_pipeline_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]


 def test_get_training_pipeline_flattened():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_training_pipeline),
-        '__call__') as call:
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = training_pipeline.TrainingPipeline()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.get_training_pipeline(
-            name='name_value',
-        )
+        client.get_training_pipeline(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 def test_get_training_pipeline_flattened_error():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.get_training_pipeline(
-            pipeline_service.GetTrainingPipelineRequest(),
-            name='name_value',
+            pipeline_service.GetTrainingPipelineRequest(), name="name_value",
         )


 @pytest.mark.asyncio
 async def test_get_training_pipeline_flattened_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.get_training_pipeline),
-        '__call__') as call:
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = training_pipeline.TrainingPipeline()

-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(training_pipeline.TrainingPipeline())
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            training_pipeline.TrainingPipeline()
+        )

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        response = await client.get_training_pipeline(
-            name='name_value',
-        )
+        response = await client.get_training_pipeline(name="name_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0].name == 'name_value'
+        assert args[0].name == "name_value"


 @pytest.mark.asyncio
 async def test_get_training_pipeline_flattened_error_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         await client.get_training_pipeline(
-            pipeline_service.GetTrainingPipelineRequest(),
-            name='name_value',
+            pipeline_service.GetTrainingPipelineRequest(), name="name_value",
         )


-def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_service.ListTrainingPipelinesRequest):
+def test_list_training_pipelines(
+    transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest
+):
     client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -861,12 +955,11 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        '__call__') as call:
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = pipeline_service.ListTrainingPipelinesResponse(
-            next_page_token='next_page_token_value',
-
+            next_page_token="next_page_token_value",
         )

         response = client.list_training_pipelines(request)
@@ -881,7 +974,7 @@ def test_list_training_pipelines(transport: str = 'grpc', request_type=pipeline_

     assert isinstance(response, pagers.ListTrainingPipelinesPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 def test_list_training_pipelines_from_dict():
@@ -889,10 +982,12 @@ def test_list_training_pipelines_from_dict():


 @pytest.mark.asyncio
-async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.ListTrainingPipelinesRequest):
+async def test_list_training_pipelines_async(
+    transport: str = "grpc_asyncio",
+    request_type=pipeline_service.ListTrainingPipelinesRequest,
+):
     client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-        transport=transport,
+        credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
@@ -901,12 +996,14 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', re

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        '__call__') as call:
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse(
-            next_page_token='next_page_token_value',
-        ))
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            pipeline_service.ListTrainingPipelinesResponse(
+                next_page_token="next_page_token_value",
+            )
+        )

         response = await client.list_training_pipelines(request)

@@ -919,7 +1016,7 @@ async def test_list_training_pipelines_async(transport: str = 'grpc_asyncio', re
     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager)

-    assert response.next_page_token == 'next_page_token_value'
+    assert response.next_page_token == "next_page_token_value"


 @pytest.mark.asyncio
@@ -928,19 +1025,17 @@ async def test_list_training_pipelines_async_from_dict():


 def test_list_training_pipelines_field_headers():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.ListTrainingPipelinesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        '__call__') as call:
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
         call.return_value = pipeline_service.ListTrainingPipelinesResponse()

         client.list_training_pipelines(request)
@@ -952,28 +1047,25 @@ def test_list_training_pipelines_field_headers():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 @pytest.mark.asyncio
 async def test_list_training_pipelines_field_headers_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Any value that is part of the HTTP/1.1 URI should be sent as
     # a field header. Set these to a non-empty value.
     request = pipeline_service.ListTrainingPipelinesRequest()
-    request.parent = 'parent/value'
+    request.parent = "parent/value"

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse())
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            pipeline_service.ListTrainingPipelinesResponse()
+        )

         await client.list_training_pipelines(request)

@@ -984,104 +1076,87 @@ async def test_list_training_pipelines_field_headers_async():

     # Establish that the field header was sent.
     _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'parent=parent/value',
-    ) in kw['metadata']
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]


 def test_list_training_pipelines_flattened():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client.transport.list_training_pipelines),
-        '__call__') as call:
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = pipeline_service.ListTrainingPipelinesResponse()

         # Call the method with a truthy value for each flattened field,
         # using the keyword arguments to the method.
-        client.list_training_pipelines(
-            parent='parent_value',
-        )
+        client.list_training_pipelines(parent="parent_value",)

         # Establish that the underlying call was made with the expected
         # request object values.
         assert len(call.mock_calls) == 1
         _, args, _ = call.mock_calls[0]

-        assert args[0].parent == 'parent_value'
+        assert args[0].parent == "parent_value"


 def test_list_training_pipelines_flattened_error():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)

     # Attempting to call a method with both a request object and flattened
     # fields is an error.
     with pytest.raises(ValueError):
         client.list_training_pipelines(
-            pipeline_service.ListTrainingPipelinesRequest(),
-            parent='parent_value',
+            pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
         )


 @pytest.mark.asyncio
 async def test_list_training_pipelines_flattened_async():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = pipeline_service.ListTrainingPipelinesResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(pipeline_service.ListTrainingPipelinesResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_training_pipelines( - parent='parent_value', - ) + response = await client.list_training_pipelines(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), - parent='parent_value', + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", ) def test_list_training_pipelines_pager(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1090,17 +1165,14 @@ def test_list_training_pipelines_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1113,9 +1185,7 @@ def test_list_training_pipelines_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_training_pipelines(request={}) @@ -1123,18 +1193,16 @@ def test_list_training_pipelines_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in results) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) + def test_list_training_pipelines_pages(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__') as call: + type(client.transport.list_training_pipelines), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1143,17 +1211,14 @@ def test_list_training_pipelines_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1164,19 +1229,20 @@ def test_list_training_pipelines_pages(): RuntimeError, ) pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1185,17 +1251,14 @@ async def test_list_training_pipelines_async_pager(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1206,25 +1269,25 @@ async def test_list_training_pipelines_async_pager(): RuntimeError, ) async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) - for i in responses) + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) + @pytest.mark.asyncio async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_training_pipelines), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_training_pipelines), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( pipeline_service.ListTrainingPipelinesResponse( @@ -1233,17 +1296,14 @@ async def test_list_training_pipelines_async_pages(): training_pipeline.TrainingPipeline(), training_pipeline.TrainingPipeline(), ], - next_page_token='abc', + next_page_token="abc", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], - next_page_token='def', + training_pipelines=[], next_page_token="def", ), pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - ], - next_page_token='ghi', + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", ), pipeline_service.ListTrainingPipelinesResponse( training_pipelines=[ @@ -1256,14 +1316,15 @@ async def test_list_training_pipelines_async_pages(): pages = [] async for page_ in (await client.list_training_pipelines(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.DeleteTrainingPipelineRequest): +def test_delete_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1272,10 +1333,10 @@ def test_delete_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_training_pipeline(request) @@ -1294,10 +1355,12 @@ def test_delete_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.DeleteTrainingPipelineRequest): +async def test_delete_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.DeleteTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1306,11 +1369,11 @@ async def test_delete_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_training_pipeline(request) @@ -1331,20 +1394,18 @@ async def test_delete_training_pipeline_async_from_dict(): def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_training_pipeline(request) @@ -1355,28 +1416,25 @@ def test_delete_training_pipeline_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_training_pipeline(request) @@ -1387,101 +1445,85 @@ async def test_delete_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_training_pipeline( - name='name_value', - ) + client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_training_pipeline), - '__call__') as call: + type(client.transport.delete_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_training_pipeline( - name='name_value', - ) + response = await client.delete_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_delete_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. 
with pytest.raises(ValueError): await client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), - name='name_value', + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", ) -def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline_service.CancelTrainingPipelineRequest): +def test_cancel_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest +): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1490,8 +1532,8 @@ def test_cancel_training_pipeline(transport: str = 'grpc', request_type=pipeline # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1512,10 +1554,12 @@ def test_cancel_training_pipeline_from_dict(): @pytest.mark.asyncio -async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', request_type=pipeline_service.CancelTrainingPipelineRequest): +async def test_cancel_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CancelTrainingPipelineRequest, +): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1524,8 +1568,8 @@ async def test_cancel_training_pipeline_async(transport: str = 'grpc_asyncio', r # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1547,19 +1591,17 @@ async def test_cancel_training_pipeline_async_from_dict(): def test_cancel_training_pipeline_field_headers(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = None client.cancel_training_pipeline(request) @@ -1571,27 +1613,22 @@ def test_cancel_training_pipeline_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio async def test_cancel_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = pipeline_service.CancelTrainingPipelineRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.cancel_training_pipeline(request) @@ -1603,92 +1640,75 @@ async def test_cancel_training_pipeline_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.cancel_training_pipeline( - name='name_value', - ) + client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.cancel_training_pipeline), - '__call__') as call: + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = None call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.cancel_training_pipeline( - name='name_value', - ) + response = await client.cancel_training_pipeline(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio async def test_cancel_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) # Attempting to call a method with both a request object and flattened # fields is an error. with pytest.raises(ValueError): await client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), - name='name_value', + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", ) @@ -1699,8 +1719,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1719,8 +1738,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = PipelineServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1748,13 +1766,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.PipelineServiceGrpcTransport, - transports.PipelineServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1762,13 +1783,8 @@ def test_transport_adc(transport_class): def test_transport_grpc_default(): # A client should use the gRPC transport by default. - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.PipelineServiceGrpcTransport, - ) + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) def test_pipeline_service_base_transport_error(): @@ -1776,13 +1792,15 @@ def test_pipeline_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_pipeline_service_base_transport(): # Instantiate the base transport. 
- with mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.PipelineServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1791,12 +1809,12 @@ def test_pipeline_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_training_pipeline', - 'get_training_pipeline', - 'list_training_pipelines', - 'delete_training_pipeline', - 'cancel_training_pipeline', - ) + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1809,23 +1827,28 @@ def test_pipeline_service_base_transport(): def test_pipeline_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_pipeline_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.PipelineServiceTransport() @@ -1834,11 +1857,11 @@ def test_pipeline_service_base_transport_with_adc(): def test_pipeline_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. 
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) PipelineServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1846,37 +1869,43 @@ def test_pipeline_service_auth_adc(): def test_pipeline_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.PipelineServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_pipeline_service_host_no_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_pipeline_service_host_with_port(): client = PipelineServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_pipeline_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.PipelineServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1884,24 +1913,33 @@ def test_pipeline_service_grpc_transport_channel(): def test_pipeline_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.PipelineServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) def test_pipeline_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1910,7 +1948,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1926,9 +1964,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1936,17 +1972,23 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize("transport_class", [transports.PipelineServiceGrpcTransport, transports.PipelineServiceGrpcAsyncIOTransport]) -def test_pipeline_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1963,9 +2005,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1974,16 +2014,12 @@ def test_pipeline_service_transport_channel_mtls_with_adc( def test_pipeline_service_grpc_lro_client(): client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), 
- transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1991,36 +2027,34 @@ def test_pipeline_service_grpc_lro_client(): def test_pipeline_service_grpc_lro_async_client(): client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + def test_endpoint_path(): project = "squid" location = "clam" endpoint = "whelk" - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(project=project, location=location, endpoint=endpoint, ) + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) actual = PipelineServiceClient.endpoint_path(project, location, endpoint) assert expected == actual def test_parse_endpoint_path(): expected = { - "project": "octopus", - "location": "oyster", - "endpoint": "nudibranch", - + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", } path = PipelineServiceClient.endpoint_path(**expected) @@ -2028,22 +2062,24 @@ def test_parse_endpoint_path(): actual = PipelineServiceClient.parse_endpoint_path(path) assert expected == actual + def test_model_path(): project = "cuttlefish" location = "mussel" model = "winkle" - expected = "projects/{project}/locations/{location}/models/{model}".format(project=project, location=location, model=model, ) + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) actual = PipelineServiceClient.model_path(project, location, model) assert expected == actual def test_parse_model_path(): expected = { - "project": "nautilus", - "location": "scallop", - "model": "abalone", - + "project": "nautilus", + "location": "scallop", + "model": "abalone", } path = PipelineServiceClient.model_path(**expected) @@ -2051,22 +2087,26 @@ def test_parse_model_path(): actual = PipelineServiceClient.parse_model_path(path) assert expected == actual + def test_training_pipeline_path(): project = "squid" location = "clam" training_pipeline = "whelk" - expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(project=project, location=location, training_pipeline=training_pipeline, ) - actual = PipelineServiceClient.training_pipeline_path(project, location, training_pipeline) + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = PipelineServiceClient.training_pipeline_path( + project, location, training_pipeline + ) assert expected == actual 
def test_parse_training_pipeline_path(): expected = { - "project": "octopus", - "location": "oyster", - "training_pipeline": "nudibranch", - + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", } path = PipelineServiceClient.training_pipeline_path(**expected) @@ -2074,18 +2114,20 @@ def test_parse_training_pipeline_path(): actual = PipelineServiceClient.parse_training_pipeline_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = PipelineServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = PipelineServiceClient.common_billing_account_path(**expected) @@ -2093,18 +2135,18 @@ def test_parse_common_billing_account_path(): actual = PipelineServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = PipelineServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = PipelineServiceClient.common_folder_path(**expected) @@ -2112,18 +2154,18 @@ def test_parse_common_folder_path(): actual = PipelineServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = PipelineServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = PipelineServiceClient.common_organization_path(**expected) @@ -2131,18 +2173,18 @@ def test_parse_common_organization_path(): actual = PipelineServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = PipelineServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = PipelineServiceClient.common_project_path(**expected) @@ -2150,20 +2192,22 @@ def test_parse_common_project_path(): actual = PipelineServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = PipelineServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = PipelineServiceClient.common_location_path(**expected) 
@@ -2175,17 +2219,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.PipelineServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = PipelineServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index a9a2977768..6c1061d588 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -35,8 +35,12 @@ from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import SpecialistPoolServiceClient +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -56,7 +60,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -67,17 +75,36 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) -@pytest.mark.parametrize("client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient]) +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient] +) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds @@ -85,7 +112,7 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_client_get_transport_class(): @@ -96,29 +123,48 @@ def test_specialist_pool_service_client_get_transport_class(): assert transport == transports.SpecialistPoolServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) -def test_specialist_pool_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + 
transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=credentials.AnonymousCredentials() - ) + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -134,7 +180,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -150,7 +196,7 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() patched.assert_called_once_with( @@ -170,13 +216,15 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client = client_class() # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class() # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -189,26 +237,66 @@ def test_specialist_pool_service_client_client_options(client_class, transport_c client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "true"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc", "false"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio", "false") -]) -@mock.patch.object(SpecialistPoolServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceClient)) -@mock.patch.object(SpecialistPoolServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(SpecialistPoolServiceAsyncClient)) + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "true", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "false", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_specialist_pool_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: ssl_channel_creds = mock.Mock() - with mock.patch('grpc.ssl_channel_credentials', return_value=ssl_channel_creds): + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): patched.return_value = None client = client_class(client_options=options) @@ -231,11 +319,21 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: - with mock.patch('google.auth.transport.grpc.SslCredentials.ssl_credentials', new_callable=mock.PropertyMock) as ssl_credentials_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: if use_client_cert_env == "false": is_mtls_mock.return_value = False ssl_credentials_mock.return_value = None @@ -245,7 +343,9 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl is_mtls_mock.return_value = True ssl_credentials_mock.return_value = mock.Mock() expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ssl_credentials_mock.return_value + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) patched.return_value = None client = client_class() @@ -260,10 +360,17 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.grpc.SslCredentials.__init__', return_value=None): - with mock.patch('google.auth.transport.grpc.SslCredentials.is_mtls', new_callable=mock.PropertyMock) as is_mtls_mock: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: is_mtls_mock.return_value = False patched.return_value = None client = client_class() @@ -278,16 +385,27 @@ def test_specialist_pool_service_client_mtls_env_auto(client_class, transport_cl ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_specialist_pool_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. - options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -300,16 +418,28 @@ def test_specialist_pool_service_client_client_options_scopes(client_class, tran client_info=transports.base.DEFAULT_CLIENT_INFO, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (SpecialistPoolServiceClient, transports.SpecialistPoolServiceGrpcTransport, "grpc"), - (SpecialistPoolServiceAsyncClient, transports.SpecialistPoolServiceGrpcAsyncIOTransport, "grpc_asyncio") -]) -def test_specialist_pool_service_client_client_options_credentials_file(client_class, transport_class, transport_name): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): # Check the case credentials file is provided. 
- options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options) patched.assert_called_once_with( @@ -324,10 +454,12 @@ def test_specialist_pool_service_client_client_options_credentials_file(client_c def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = SpecialistPoolServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -340,10 +472,12 @@ def test_specialist_pool_service_client_client_options_from_dict(): ) -def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +def test_create_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -352,10 +486,10 @@ def test_create_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_specialist_pool(request) @@ -374,10 +508,12 @@ def test_create_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.CreateSpecialistPoolRequest): +async def test_create_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -386,11 +522,11 @@ async def test_create_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_specialist_pool(request) @@ -418,13 +554,13 @@ def test_create_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.create_specialist_pool(request) @@ -435,10 +571,7 @@ def test_create_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -450,13 +583,15 @@ async def test_create_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.create_specialist_pool(request) @@ -467,10 +602,7 @@ async def test_create_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_create_specialist_pool_flattened(): @@ -480,16 +612,16 @@ def test_create_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
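# Aside: the x-goog-request-params assertions above check the routing header
# GAPIC clients attach for URI-bound request fields. A hypothetical helper
# sketching the convention (the real one is
# gapic_v1.routing_header.to_grpc_metadata):

from urllib.parse import quote


def routing_metadata(fields):
    params = "&".join(f"{k}={quote(str(v), safe='/')}" for k, v in fields.items())
    return ("x-goog-request-params", params)


assert routing_metadata({"parent": "parent/value"}) == (
    "x-goog-request-params",
    "parent=parent/value",
)
# Nested message fields are addressed with dotted paths, as in the
# update_specialist_pool header checks further down.
assert routing_metadata({"specialist_pool.name": "pools/123"}) == (
    "x-goog-request-params",
    "specialist_pool.name=pools/123",
)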
client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -497,9 +629,11 @@ def test_create_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) def test_create_specialist_pool_flattened_error(): @@ -512,8 +646,8 @@ def test_create_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) @@ -525,19 +659,19 @@ async def test_create_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_specialist_pool), - '__call__') as call: + type(client.transport.create_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
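# Aside: "flattened" calls let callers pass leaf fields as keyword arguments
# instead of building the request message, but mixing the two is rejected,
# which is what the *_flattened_error tests pin down. A hypothetical sketch
# of that rule:


def create_specialist_pool(request=None, *, parent=None, specialist_pool=None):
    if request is not None and any(v is not None for v in (parent, specialist_pool)):
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )
    return request or {"parent": parent, "specialist_pool": specialist_pool}


assert create_specialist_pool(parent="parent_value")["parent"] == "parent_value"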
response = await client.create_specialist_pool( - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -545,9 +679,11 @@ async def test_create_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) @pytest.mark.asyncio @@ -561,15 +697,17 @@ async def test_create_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.create_specialist_pool( specialist_pool_service.CreateSpecialistPoolRequest(), - parent='parent_value', - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), ) -def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.GetSpecialistPoolRequest): +def test_get_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -578,20 +716,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = specialist_pool.SpecialistPool( - name='name_value', - - display_name='display_name_value', - + name="name_value", + display_name="display_name_value", specialist_managers_count=2662, - - specialist_manager_emails=['specialist_manager_emails_value'], - - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], ) response = client.get_specialist_pool(request) @@ -606,15 +739,15 @@ def test_get_specialist_pool(transport: str = 'grpc', request_type=specialist_po assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] def test_get_specialist_pool_from_dict(): @@ -622,10 +755,12 @@ def test_get_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.GetSpecialistPoolRequest): +async def test_get_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -634,16 +769,18 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', reques # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool( - name='name_value', - display_name='display_name_value', - specialist_managers_count=2662, - specialist_manager_emails=['specialist_manager_emails_value'], - pending_data_labeling_jobs=['pending_data_labeling_jobs_value'], - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + ) response = await client.get_specialist_pool(request) @@ -656,15 +793,15 @@ async def test_get_specialist_pool_async(transport: str = 'grpc_asyncio', reques # Establish that the response is the type that we expect. 
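# Aside: on the async side the canned responses above are wrapped so that
# `await` works. A minimal stand-in for grpc_helpers_async.FakeUnaryUnaryCall
# showing why: awaiting the fake call object yields the canned message.

import asyncio
from unittest import mock


class FakeCall:
    def __init__(self, response):
        self._response = response

    def __await__(self):
        async def _resolve():
            return self._response

        return _resolve().__await__()


async def main():
    call = mock.Mock(return_value=FakeCall("canned-pool"))
    assert await call("request") == "canned-pool"


asyncio.run(main())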
assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == 'name_value' + assert response.name == "name_value" - assert response.display_name == 'display_name_value' + assert response.display_name == "display_name_value" assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ['specialist_manager_emails_value'] + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ['pending_data_labeling_jobs_value'] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] @pytest.mark.asyncio @@ -680,12 +817,12 @@ def test_get_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: call.return_value = specialist_pool.SpecialistPool() client.get_specialist_pool(request) @@ -697,10 +834,7 @@ def test_get_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -712,13 +846,15 @@ async def test_get_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) await client.get_specialist_pool(request) @@ -729,10 +865,7 @@ async def test_get_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_get_specialist_pool_flattened(): @@ -742,23 +875,21 @@ def test_get_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.get_specialist_pool( - name='name_value', - ) + client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_get_specialist_pool_flattened_error(): @@ -770,8 +901,7 @@ def test_get_specialist_pool_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) @@ -783,24 +913,24 @@ async def test_get_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.get_specialist_pool), - '__call__') as call: + type(client.transport.get_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool.SpecialistPool() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool.SpecialistPool()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.get_specialist_pool( - name='name_value', - ) + response = await client.get_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -813,15 +943,16 @@ async def test_get_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", ) -def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +def test_list_specialist_pools( + transport: str = "grpc", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -830,12 +961,11 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. 
call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - + next_page_token="next_page_token_value", ) response = client.list_specialist_pools(request) @@ -850,7 +980,7 @@ def test_list_specialist_pools(transport: str = 'grpc', request_type=specialist_ assert isinstance(response, pagers.ListSpecialistPoolsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_specialist_pools_from_dict(): @@ -858,10 +988,12 @@ def test_list_specialist_pools_from_dict(): @pytest.mark.asyncio -async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.ListSpecialistPoolsRequest): +async def test_list_specialist_pools_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -870,12 +1002,14 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', requ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_specialist_pools(request) @@ -888,7 +1022,7 @@ async def test_list_specialist_pools_async(transport: str = 'grpc_asyncio', requ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -904,12 +1038,12 @@ def test_list_specialist_pools_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() client.list_specialist_pools(request) @@ -921,10 +1055,7 @@ def test_list_specialist_pools_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] @pytest.mark.asyncio @@ -936,13 +1067,15 @@ async def test_list_specialist_pools_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. 
request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = 'parent/value' + request.parent = "parent/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + type(client.transport.list_specialist_pools), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) await client.list_specialist_pools(request) @@ -953,10 +1086,7 @@ async def test_list_specialist_pools_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'parent=parent/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] def test_list_specialist_pools_flattened(): @@ -966,23 +1096,21 @@ def test_list_specialist_pools_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.list_specialist_pools( - parent='parent_value', - ) + client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" def test_list_specialist_pools_flattened_error(): @@ -994,8 +1122,7 @@ def test_list_specialist_pools_flattened_error(): # fields is an error. with pytest.raises(ValueError): client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) @@ -1007,24 +1134,24 @@ async def test_list_specialist_pools_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(specialist_pool_service.ListSpecialistPoolsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.list_specialist_pools( - parent='parent_value', - ) + response = await client.list_specialist_pools(parent="parent_value",) # Establish that the underlying call was made with the expected # request object values. 
assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].parent == 'parent_value' + assert args[0].parent == "parent_value" @pytest.mark.asyncio @@ -1037,20 +1164,17 @@ async def test_list_specialist_pools_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), - parent='parent_value', + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", ) def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1059,17 +1183,14 @@ def test_list_specialist_pools_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1082,9 +1203,7 @@ def test_list_specialist_pools_pager(): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_specialist_pools(request={}) @@ -1092,18 +1211,16 @@ def test_list_specialist_pools_pager(): results = [i for i in pager] assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in results) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) + def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials, - ) + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__') as call: + type(client.transport.list_specialist_pools), "__call__" + ) as call: # Set the response to a series of pages. 
call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1112,17 +1229,14 @@ def test_list_specialist_pools_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1133,9 +1247,10 @@ def test_list_specialist_pools_pages(): RuntimeError, ) pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_specialist_pools_async_pager(): client = SpecialistPoolServiceAsyncClient( @@ -1144,8 +1259,10 @@ async def test_list_specialist_pools_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1154,17 +1271,14 @@ async def test_list_specialist_pools_async_pager(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1175,14 +1289,14 @@ async def test_list_specialist_pools_async_pager(): RuntimeError, ) async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == 'abc' + assert async_pager.next_page_token == "abc" responses = [] async for response in async_pager: responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) - for i in responses) + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) + @pytest.mark.asyncio async def test_list_specialist_pools_async_pages(): @@ -1192,8 +1306,10 @@ async def test_list_specialist_pools_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_specialist_pools), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
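# Aside: the pager tests above feed mock.side_effect a sequence of pages;
# each stub invocation pops the next item, and the trailing RuntimeError
# would surface if the pager requested a page past the empty token. A
# standalone sketch of that consumption loop:

from unittest import mock

call = mock.Mock()
call.side_effect = [
    {"items": [1, 2, 3], "next_page_token": "abc"},
    {"items": [], "next_page_token": "def"},
    {"items": [4], "next_page_token": "ghi"},
    {"items": [5, 6], "next_page_token": ""},
    RuntimeError,  # sentinel: one call too many means a pager bug
]

results = []
page = call()
while True:
    results.extend(page["items"])
    if not page["next_page_token"]:
        break
    page = call()
assert results == [1, 2, 3, 4, 5, 6]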
call.side_effect = ( specialist_pool_service.ListSpecialistPoolsResponse( @@ -1202,17 +1318,14 @@ async def test_list_specialist_pools_async_pages(): specialist_pool.SpecialistPool(), specialist_pool.SpecialistPool(), ], - next_page_token='abc', + next_page_token="abc", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], - next_page_token='def', + specialist_pools=[], next_page_token="def", ), specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - ], - next_page_token='ghi', + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", ), specialist_pool_service.ListSpecialistPoolsResponse( specialist_pools=[ @@ -1225,14 +1338,16 @@ async def test_list_specialist_pools_async_pages(): pages = [] async for page_ in (await client.list_specialist_pools(request={})).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +def test_delete_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1241,10 +1356,10 @@ def test_delete_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.delete_specialist_pool(request) @@ -1263,10 +1378,12 @@ def test_delete_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.DeleteSpecialistPoolRequest): +async def test_delete_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1275,11 +1392,11 @@ async def test_delete_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
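# Aside: create/delete/update SpecialistPool are long-running operations, so
# the stubs above return an Operation handle rather than the resource itself.
# A tiny sketch (assuming googleapis-common-protos is installed):

from google.longrunning import operations_pb2

op = operations_pb2.Operation(name="operations/spam")
assert op.name == "operations/spam"
assert not op.done  # still running; real callers poll or await completion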
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.delete_specialist_pool(request) @@ -1307,13 +1424,13 @@ def test_delete_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.delete_specialist_pool(request) @@ -1324,10 +1441,7 @@ def test_delete_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] @pytest.mark.asyncio @@ -1339,13 +1453,15 @@ async def test_delete_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = 'name/value' + request.name = "name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.delete_specialist_pool(request) @@ -1356,10 +1472,7 @@ async def test_delete_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name/value', - ) in kw['metadata'] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] def test_delete_specialist_pool_flattened(): @@ -1369,23 +1482,21 @@ def test_delete_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - client.delete_specialist_pool( - name='name_value', - ) + client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" def test_delete_specialist_pool_flattened_error(): @@ -1397,8 +1508,7 @@ def test_delete_specialist_pool_flattened_error(): # fields is an error. 
with pytest.raises(ValueError): client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) @@ -1410,26 +1520,24 @@ async def test_delete_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_specialist_pool), - '__call__') as call: + type(client.transport.delete_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. - response = await client.delete_specialist_pool( - name='name_value', - ) + response = await client.delete_specialist_pool(name="name_value",) # Establish that the underlying call was made with the expected # request object values. assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].name == 'name_value' + assert args[0].name == "name_value" @pytest.mark.asyncio @@ -1442,15 +1550,16 @@ async def test_delete_specialist_pool_flattened_error_async(): # fields is an error. with pytest.raises(ValueError): await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), - name='name_value', + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", ) -def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +def test_update_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1459,10 +1568,10 @@ def test_update_specialist_pool(transport: str = 'grpc', request_type=specialist # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.update_specialist_pool(request) @@ -1481,10 +1590,12 @@ def test_update_specialist_pool_from_dict(): @pytest.mark.asyncio -async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', request_type=specialist_pool_service.UpdateSpecialistPoolRequest): +async def test_update_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, @@ -1493,11 +1604,11 @@ async def test_update_specialist_pool_async(transport: str = 'grpc_asyncio', req # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.update_specialist_pool(request) @@ -1525,13 +1636,13 @@ def test_update_specialist_pool_field_headers(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = operations_pb2.Operation(name='operations/op') + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") client.update_specialist_pool(request) @@ -1543,9 +1654,9 @@ def test_update_specialist_pool_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1557,13 +1668,15 @@ async def test_update_specialist_pool_field_headers_async(): # Any value that is part of the HTTP/1.1 URI should be sent as # a field header. Set these to a non-empty value. request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = 'specialist_pool.name/value' + request.specialist_pool.name = "specialist_pool.name/value" # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(operations_pb2.Operation(name='operations/op')) + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) await client.update_specialist_pool(request) @@ -1575,9 +1688,9 @@ async def test_update_specialist_pool_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'specialist_pool.name=specialist_pool.name/value', - ) in kw['metadata'] + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] def test_update_specialist_pool_flattened(): @@ -1587,16 +1700,16 @@ def test_update_specialist_pool_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1604,9 +1717,11 @@ def test_update_specialist_pool_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) def test_update_specialist_pool_flattened_error(): @@ -1619,8 +1734,8 @@ def test_update_specialist_pool_flattened_error(): with pytest.raises(ValueError): client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1632,19 +1747,19 @@ async def test_update_specialist_pool_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_specialist_pool), - '__call__') as call: + type(client.transport.update_specialist_pool), "__call__" + ) as call: # Designate an appropriate return value for the call. 
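# Aside: update_specialist_pool pairs the resource with a FieldMask naming
# the paths to write; unlisted fields are left untouched. A sketch with an
# illustrative path ("display_name" here; the tests above use the
# placeholder "paths_value"):

from google.protobuf import field_mask_pb2

mask = field_mask_pb2.FieldMask(paths=["display_name"])
assert list(mask.paths) == ["display_name"]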
- call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1652,9 +1767,11 @@ async def test_update_specialist_pool_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(name='name_value') + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) - assert args[0].update_mask == field_mask.FieldMask(paths=['paths_value']) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) @pytest.mark.asyncio @@ -1668,8 +1785,8 @@ async def test_update_specialist_pool_flattened_error_async(): with pytest.raises(ValueError): await client.update_specialist_pool( specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name='name_value'), - update_mask=field_mask.FieldMask(paths=['paths_value']), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), ) @@ -1680,8 +1797,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport=transport, + credentials=credentials.AnonymousCredentials(), transport=transport, ) # It is an error to provide a credentials file and a transport instance. @@ -1700,8 +1816,7 @@ def test_credentials_transport_error(): ) with pytest.raises(ValueError): client = SpecialistPoolServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, + client_options={"scopes": ["1", "2"]}, transport=transport, ) @@ -1729,13 +1844,16 @@ def test_transport_get_channel(): assert channel -@pytest.mark.parametrize("transport_class", [ - transports.SpecialistPoolServiceGrpcTransport, - transports.SpecialistPoolServiceGrpcAsyncIOTransport -]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. 
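# Aside: the credentials/transport error tests above enforce that an explicit
# transport instance is mutually exclusive with credentials, credentials_file,
# and scopes, since the transport already carries its own. A hypothetical
# guard sketching that rule:

import pytest


def make_client(transport=None, credentials=None, credentials_file=None, scopes=None):
    if transport is not None and any((credentials, credentials_file, scopes)):
        raise ValueError("transport is mutually exclusive with credential options")
    return object()  # stand-in for a constructed client


with pytest.raises(ValueError):
    make_client(transport=object(), scopes=["1", "2"])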
- with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() @@ -1746,10 +1864,7 @@ def test_transport_grpc_default(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), ) - assert isinstance( - client.transport, - transports.SpecialistPoolServiceGrpcTransport, - ) + assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) def test_specialist_pool_service_base_transport_error(): @@ -1757,13 +1872,15 @@ def test_specialist_pool_service_base_transport_error(): with pytest.raises(exceptions.DuplicateCredentialArgs): transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_specialist_pool_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__') as Transport: + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.SpecialistPoolServiceTransport( credentials=credentials.AnonymousCredentials(), @@ -1772,12 +1889,12 @@ def test_specialist_pool_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'create_specialist_pool', - 'get_specialist_pool', - 'list_specialist_pools', - 'delete_specialist_pool', - 'update_specialist_pool', - ) + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) for method in methods: with pytest.raises(NotImplementedError): getattr(transport, method)(request=object()) @@ -1790,23 +1907,28 @@ def test_specialist_pool_service_base_transport(): def test_specialist_pool_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(auth, 'load_credentials_from_file') as load_creds, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", + credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) def test_specialist_pool_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. 
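# Aside: several of the transport tests above patch the Application Default
# Credentials seam. A minimal sketch (assuming google-auth is installed):
# google.auth.default() returns a (credentials, project_id) tuple, and
# patching it keeps unit tests independent of the machine's real ADC setup.

from unittest import mock

import google.auth

with mock.patch.object(google.auth, "default") as adc:
    adc.return_value = (object(), None)  # (credentials, project_id)
    creds, project = google.auth.default()
    adc.assert_called_once()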
- with mock.patch.object(auth, 'default') as adc, mock.patch('google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (credentials.AnonymousCredentials(), None) transport = transports.SpecialistPoolServiceTransport() @@ -1815,11 +1937,11 @@ def test_specialist_pool_service_base_transport_with_adc(): def test_specialist_pool_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) SpecialistPoolServiceClient() - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id=None, ) @@ -1827,37 +1949,43 @@ def test_specialist_pool_service_auth_adc(): def test_specialist_pool_service_transport_auth_adc(): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (credentials.AnonymousCredentials(), None) - transports.SpecialistPoolServiceGrpcTransport(host="squid.clam.whelk", quota_project_id="octopus") - adc.assert_called_once_with(scopes=( - 'https://www.googleapis.com/auth/cloud-platform',), + transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), quota_project_id="octopus", ) + def test_specialist_pool_service_host_no_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:443' + assert client.transport._host == "aiplatform.googleapis.com:443" def test_specialist_pool_service_host_with_port(): client = SpecialistPoolServiceClient( credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='aiplatform.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), ) - assert client.transport._host == 'aiplatform.googleapis.com:8000' + assert client.transport._host == "aiplatform.googleapis.com:8000" def test_specialist_pool_service_grpc_transport_channel(): - channel = grpc.insecure_channel('http://localhost/') + channel = grpc.insecure_channel("http://localhost/") # Check that channel is used if provided. 
transport = transports.SpecialistPoolServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" @@ -1865,24 +1993,33 @@ def test_specialist_pool_service_grpc_transport_channel(): def test_specialist_pool_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel('http://localhost/') + channel = aio.insecure_channel("http://localhost/") # Check that channel is used if provided. transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" assert transport._ssl_channel_credentials == None -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1891,7 +2028,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( cred = credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, 'default') as adc: + with mock.patch.object(auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1907,9 +2044,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( "mtls.squid.clam.whelk:443", credentials=cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1917,17 +2052,23 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred -@pytest.mark.parametrize("transport_class", [transports.SpecialistPoolServiceGrpcTransport, transports.SpecialistPoolServiceGrpcAsyncIOTransport]) -def test_specialist_pool_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel", autospec=True) as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: mock_grpc_channel = 
mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -1944,9 +2085,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( "mtls.squid.clam.whelk:443", credentials=mock_cred, credentials_file=None, - scopes=( - 'https://www.googleapis.com/auth/cloud-platform', - ), + scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, ) @@ -1955,16 +2094,12 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc( def test_specialist_pool_service_grpc_lro_client(): client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc', + credentials=credentials.AnonymousCredentials(), transport="grpc", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client @@ -1972,36 +2107,36 @@ def test_specialist_pool_service_grpc_lro_client(): def test_specialist_pool_service_grpc_lro_async_client(): client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - transport='grpc_asyncio', + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) transport = client.transport # Ensure that we have a api-core operations client. - assert isinstance( - transport.operations_client, - operations_v1.OperationsAsyncClient, - ) + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. 
assert transport.operations_client is transport.operations_client + def test_specialist_pool_path(): project = "squid" location = "clam" specialist_pool = "whelk" - expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(project=project, location=location, specialist_pool=specialist_pool, ) - actual = SpecialistPoolServiceClient.specialist_pool_path(project, location, specialist_pool) + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + actual = SpecialistPoolServiceClient.specialist_pool_path( + project, location, specialist_pool + ) assert expected == actual def test_parse_specialist_pool_path(): expected = { - "project": "octopus", - "location": "oyster", - "specialist_pool": "nudibranch", - + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", } path = SpecialistPoolServiceClient.specialist_pool_path(**expected) @@ -2009,18 +2144,20 @@ def test_parse_specialist_pool_path(): actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) assert expected == actual def test_parse_common_billing_account_path(): expected = { - "billing_account": "mussel", - + "billing_account": "mussel", } path = SpecialistPoolServiceClient.common_billing_account_path(**expected) @@ -2028,18 +2165,18 @@ def test_parse_common_billing_account_path(): actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format(folder=folder,) actual = SpecialistPoolServiceClient.common_folder_path(folder) assert expected == actual def test_parse_common_folder_path(): expected = { - "folder": "nautilus", - + "folder": "nautilus", } path = SpecialistPoolServiceClient.common_folder_path(**expected) @@ -2047,18 +2184,18 @@ def test_parse_common_folder_path(): actual = SpecialistPoolServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format(organization=organization,) actual = SpecialistPoolServiceClient.common_organization_path(organization) assert expected == actual def test_parse_common_organization_path(): expected = { - "organization": "abalone", - + "organization": "abalone", } path = SpecialistPoolServiceClient.common_organization_path(**expected) @@ -2066,18 +2203,18 @@ def test_parse_common_organization_path(): actual = SpecialistPoolServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format(project=project,) actual = SpecialistPoolServiceClient.common_project_path(project) assert expected == actual def test_parse_common_project_path(): expected = { - "project": "clam", - + "project": "clam", } path = 
SpecialistPoolServiceClient.common_project_path(**expected) @@ -2085,20 +2222,22 @@ def test_parse_common_project_path(): actual = SpecialistPoolServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) actual = SpecialistPoolServiceClient.common_location_path(project, location) assert expected == actual def test_parse_common_location_path(): expected = { - "project": "oyster", - "location": "nudibranch", - + "project": "oyster", + "location": "nudibranch", } path = SpecialistPoolServiceClient.common_location_path(**expected) @@ -2110,17 +2249,19 @@ def test_parse_common_location_path(): def test_client_withDEFAULT_CLIENT_INFO(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.SpecialistPoolServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = SpecialistPoolServiceClient.get_transport_class() transport = transport_class( - credentials=credentials.AnonymousCredentials(), - client_info=client_info, + credentials=credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) From 8f9c589c4182502c8caa94840dd6e3cc0c52b157 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 12:54:38 -0800 Subject: [PATCH 05/12] update tests --- synth.metadata | 2 +- .../test_prediction_service.py | 39 +++++++++++++++---- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/synth.metadata b/synth.metadata index bc34821e1f..9399d8c2e3 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,7 +4,7 @@ "git": { "name": ".", "remote": "https://github.com/dizcology/python-aiplatform.git", - "sha": "60263c04ffd04dabd7cc95c138b9f1c87566208c" + "sha": "81da030c0af8902fd54c8e7b5e92255a532d0efb" } }, { diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index 2e91a47bf4..0d30aac85e 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -480,6 +480,7 @@ def test_predict( assert args[0] == prediction_service.PredictRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, prediction_service.PredictResponse) assert response.deployed_model_id == "deployed_model_id_value" @@ -490,14 +491,16 @@ def test_predict_from_dict(): @pytest.mark.asyncio -async def test_predict_async(transport: str = "grpc_asyncio"): +async def test_predict_async( + transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest +): client = PredictionServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = prediction_service.PredictRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.predict), "__call__") as call: @@ -514,7 +517,7 @@ async def test_predict_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == prediction_service.PredictRequest() # Establish that the response is the type that we expect. assert isinstance(response, prediction_service.PredictResponse) @@ -522,6 +525,11 @@ async def test_predict_async(transport: str = "grpc_asyncio"): assert response.deployed_model_id == "deployed_model_id_value" +@pytest.mark.asyncio +async def test_predict_async_from_dict(): + await test_predict_async(request_type=dict) + + def test_predict_field_headers(): client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) @@ -603,7 +611,9 @@ def test_predict_flattened(): ] # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) def test_predict_flattened_error(): @@ -702,6 +712,7 @@ def test_explain( assert args[0] == prediction_service.ExplainRequest() # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.ExplainResponse) assert response.deployed_model_id == "deployed_model_id_value" @@ -712,14 +723,16 @@ def test_explain_from_dict(): @pytest.mark.asyncio -async def test_explain_async(transport: str = "grpc_asyncio"): +async def test_explain_async( + transport: str = "grpc_asyncio", request_type=prediction_service.ExplainRequest +): client = PredictionServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = prediction_service.ExplainRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object(type(client.transport.explain), "__call__") as call: @@ -736,7 +749,7 @@ async def test_explain_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == prediction_service.ExplainRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, prediction_service.ExplainResponse) @@ -744,6 +757,11 @@ async def test_explain_async(transport: str = "grpc_asyncio"): assert response.deployed_model_id == "deployed_model_id_value" +@pytest.mark.asyncio +async def test_explain_async_from_dict(): + await test_explain_async(request_type=dict) + + def test_explain_field_headers(): client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) @@ -826,7 +844,9 @@ def test_explain_flattened(): ] # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE) + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) assert args[0].deployed_model_id == "deployed_model_id_value" @@ -1094,6 +1114,7 @@ def test_prediction_service_grpc_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_prediction_service_grpc_asyncio_transport_channel(): @@ -1105,6 +1126,7 @@ def test_prediction_service_grpc_asyncio_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( @@ -1152,6 +1174,7 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source( quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( From 538e3e25af2ef1cbba0aa7e74fde8e5045e77cd3 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 13:10:29 -0800 Subject: [PATCH 06/12] sphinx-build warning not as errors --- noxfile.py | 1 - 1 file changed, 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 4b2538e5c9..1797beebfd 100644 --- a/noxfile.py +++ b/noxfile.py @@ -158,7 +158,6 @@ def docs(session): shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( "sphinx-build", - "-W", # warnings as errors "-T", # show full traceback on exception "-N", # no colors "-b", From 1ab0ae248324dc1dfe62dcf3152e65c8d7335844 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 13:20:14 -0800 Subject: [PATCH 07/12] fix prediction async unit tests --- .../aiplatform_v1beta1/test_prediction_service.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index 0d30aac85e..e47e0f62c5 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -663,9 +663,10 @@ async def test_predict_flattened_async(): struct.Value(null_value=struct.NullValue.NULL_VALUE) ] - assert args[0].parameters == struct.Value( - null_value=struct.NullValue.NULL_VALUE - ) + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) @pytest.mark.asyncio @@ -900,9 +901,10 @@ async def test_explain_flattened_async(): struct.Value(null_value=struct.NullValue.NULL_VALUE) ] - assert args[0].parameters == struct.Value( - null_value=struct.NullValue.NULL_VALUE - ) + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) 
assert args[0].deployed_model_id == "deployed_model_id_value" From 100aa3a2aac1519a1efa71767b27d3214b96a589 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 13:37:17 -0800 Subject: [PATCH 08/12] add metadata to setup.py --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index 82468cded3..016f2cc59d 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,10 @@ version="0.3.0", packages=setuptools.PEP420PackageFinder.find(), namespace_packages=("google", "google.cloud"), + author="Google LLC", + author_email="googleapis-packages@google.com", + license="Apache 2.0", + url="https://github.com/googleapis/python-aiplatform", platforms="Posix; MacOS X; Windows", include_package_data=True, install_requires=( From 994bfc72e8576b595a39da28676bcce31235ccdd Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 16:42:52 -0800 Subject: [PATCH 09/12] export gapic to aiplatform --- google/cloud/aiplatform/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 17172cbb56..28e26764b9 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. # +from google.cloud.aiplatform import gapic + __all__ = () From 2176333836b4e5a48830b04c0b5da7571cef5e8a Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Tue, 3 Nov 2020 16:48:52 -0800 Subject: [PATCH 10/12] update dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 016f2cc59d..84c1995b6d 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ install_requires=( "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", "libcst >= 0.2.5", - "proto-plus >= 1.4.0", + "proto-plus >= 1.10.1", "mock >= 4.0.2", "google-cloud-storage >= 1.26.0", ), From aad4a7e716688b1297d158af9504e7201d2bc3bf Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Wed, 4 Nov 2020 08:27:19 -0800 Subject: [PATCH 11/12] fix lint --- google/cloud/aiplatform/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 28e26764b9..ec30029286 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -17,4 +17,4 @@ from google.cloud.aiplatform import gapic -__all__ = () +__all__ = (gapic,) From 0e8f5d3e4580cd32a0a082a3c708f474a8345fd1 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Wed, 4 Nov 2020 11:46:40 -0800 Subject: [PATCH 12/12] address review comments --- scripts/fixup_aiplatform_v1beta1_keywords.py | 239 ------------------- setup.py | 6 +- synth.py | 1 + 3 files changed, 4 insertions(+), 242 deletions(-) delete mode 100644 scripts/fixup_aiplatform_v1beta1_keywords.py diff --git a/scripts/fixup_aiplatform_v1beta1_keywords.py b/scripts/fixup_aiplatform_v1beta1_keywords.py deleted file mode 100644 index 4842dae628..0000000000 --- a/scripts/fixup_aiplatform_v1beta1_keywords.py +++ /dev/null @@ -1,239 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import argparse -import os -import libcst as cst -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class aiplatformCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'batch_migrate_resources': ('parent', 'migrate_resource_requests', ), - 'cancel_batch_prediction_job': ('name', ), - 'cancel_custom_job': ('name', ), - 'cancel_data_labeling_job': ('name', ), - 'cancel_hyperparameter_tuning_job': ('name', ), - 'cancel_training_pipeline': ('name', ), - 'create_batch_prediction_job': ('parent', 'batch_prediction_job', ), - 'create_custom_job': ('parent', 'custom_job', ), - 'create_data_labeling_job': ('parent', 'data_labeling_job', ), - 'create_dataset': ('parent', 'dataset', ), - 'create_endpoint': ('parent', 'endpoint', ), - 'create_hyperparameter_tuning_job': ('parent', 'hyperparameter_tuning_job', ), - 'create_specialist_pool': ('parent', 'specialist_pool', ), - 'create_training_pipeline': ('parent', 'training_pipeline', ), - 'delete_batch_prediction_job': ('name', ), - 'delete_custom_job': ('name', ), - 'delete_data_labeling_job': ('name', ), - 'delete_dataset': ('name', ), - 'delete_endpoint': ('name', ), - 'delete_hyperparameter_tuning_job': ('name', ), - 'delete_model': ('name', ), - 'delete_specialist_pool': ('name', 'force', ), - 'delete_training_pipeline': ('name', ), - 'deploy_model': ('endpoint', 'deployed_model', 'traffic_split', ), - 'explain': ('endpoint', 'instances', 'parameters', 'deployed_model_id', ), - 'export_data': ('name', 'export_config', ), - 'export_model': ('name', 'output_config', ), - 'get_annotation_spec': ('name', 'read_mask', ), - 'get_batch_prediction_job': ('name', ), - 'get_custom_job': ('name', ), - 'get_data_labeling_job': ('name', ), - 'get_dataset': ('name', 'read_mask', ), - 'get_endpoint': ('name', ), - 'get_hyperparameter_tuning_job': ('name', ), - 'get_model': ('name', ), - 'get_model_evaluation': ('name', ), - 'get_model_evaluation_slice': ('name', ), - 'get_specialist_pool': ('name', ), - 'get_training_pipeline': ('name', ), - 'import_data': ('name', 'import_configs', ), - 'list_annotations': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_batch_prediction_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_custom_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_data_items': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_data_labeling_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_datasets': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_endpoints': ('parent', 'filter', 'page_size', 
'page_token', 'read_mask', ), - 'list_hyperparameter_tuning_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_model_evaluations': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_model_evaluation_slices': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_models': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_specialist_pools': ('parent', 'page_size', 'page_token', 'read_mask', ), - 'list_training_pipelines': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'predict': ('endpoint', 'instances', 'parameters', ), - 'search_migratable_resources': ('parent', 'page_size', 'page_token', ), - 'undeploy_model': ('endpoint', 'deployed_model_id', 'traffic_split', ), - 'update_dataset': ('dataset', 'update_mask', ), - 'update_endpoint': ('endpoint', 'update_mask', ), - 'update_model': ('model', 'update_mask', ), - 'update_specialist_pool': ('specialist_pool', 'update_mask', ), - 'upload_model': ('parent', 'model', ), - - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: not a.keyword.value in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), - cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. - for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=aiplatformCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. - - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. - updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the aiplatform client library. 
- -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/setup.py b/setup.py index 84c1995b6d..1b0066a279 100644 --- a/setup.py +++ b/setup.py @@ -34,12 +34,12 @@ "libcst >= 0.2.5", "proto-plus >= 1.10.1", "mock >= 4.0.2", - "google-cloud-storage >= 1.26.0", + "google-cloud-storage >= 1.26.0, < 2.0.0dev", ), python_requires=">=3.6", - scripts=["scripts/fixup_aiplatform_v1beta1_keywords.py",], + scripts=[], classifiers=[ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Operating System :: OS Independent", "Programming Language :: Python :: 3.6", diff --git a/synth.py b/synth.py index b370daadeb..935bf20fce 100644 --- a/synth.py +++ b/synth.py @@ -43,6 +43,7 @@ "setup.py", "README.rst", "docs/index.rst", + "scripts/fixup_aiplatform_v1beta1_keywords.py", "google/cloud/aiplatform/__init__.py", "tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py", ],