From 4d708cc906af2fefb04c8c85f4ec05f8e6560bf3 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Mon, 1 Mar 2021 15:51:49 -0800 Subject: [PATCH 1/7] chore: release 0.5.1 (#240) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 14 ++++++++++++++ setup.py | 2 +- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 315d5a8da1..ea9ca7a7b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +### [0.5.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.0...v0.5.1) (2021-03-01) + + +### Bug Fixes + +* fix create data labeling job samples tests ([#244](https://www.github.com/googleapis/python-aiplatform/issues/244)) ([3c440de](https://www.github.com/googleapis/python-aiplatform/commit/3c440dea14ad4d04b05ebf17ba4bcb031fe95b3e)) +* fix predict sample tests for proto-plus==1.14.2 ([#250](https://www.github.com/googleapis/python-aiplatform/issues/250)) ([b1c9d88](https://www.github.com/googleapis/python-aiplatform/commit/b1c9d88646f00b034e2576890406325db5384f10)) +* fix update export model sample, and add sample test ([#239](https://www.github.com/googleapis/python-aiplatform/issues/239)) ([20b8859](https://www.github.com/googleapis/python-aiplatform/commit/20b88592da3dd7344c7053d7fe652115ed42e4aa)) + + +### Documentation + +* update index.rst to include v1 ([#246](https://www.github.com/googleapis/python-aiplatform/issues/246)) ([82193ef](https://www.github.com/googleapis/python-aiplatform/commit/82193ef401258b17fd20895e2b0f6c95a39a24a1)) + ## [0.5.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.4.0...v0.5.0) (2021-02-17) diff --git a/setup.py b/setup.py index bb807572be..a290702738 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ import setuptools # type: ignore name = "google-cloud-aiplatform" -version = "0.5.0" +version = "0.5.1" description = "Cloud AI Platform API client library" package_root = os.path.abspath(os.path.dirname(__file__)) From 267af10ec6f9d9701afef6284f20b1e7475574e4 Mon Sep 17 00:00:00 2001 From: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com> Date: Wed, 3 Mar 2021 09:12:30 -0700 Subject: [PATCH 2/7] chore: allow merge commits (#252) Merge commits are used for syncing `dev` and `master`. --- .github/sync-repo-settings.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/sync-repo-settings.yaml b/.github/sync-repo-settings.yaml index 4930eaccc6..b703be9596 100644 --- a/.github/sync-repo-settings.yaml +++ b/.github/sync-repo-settings.yaml @@ -1,5 +1,6 @@ # https://github.com/googleapis/repo-automation-bots/tree/master/packages/sync-repo-settings # Rules for master branch protection +mergeCommitAllowed: true branchProtectionRules: # Identifies the protection rule pattern. Name of the branch to be protected. 
# Defaults to `master` From b4338dc6c1c15527d33644de2cb068981b0b9b9f Mon Sep 17 00:00:00 2001 From: WhiteSource Renovate Date: Wed, 3 Mar 2021 17:12:57 +0100 Subject: [PATCH 3/7] chore(deps): update dependency google-cloud-aiplatform to v0.5.1 (#251) --- samples/snippets/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/snippets/requirements.txt b/samples/snippets/requirements.txt index 7371c65396..b9fd33d5c1 100644 --- a/samples/snippets/requirements.txt +++ b/samples/snippets/requirements.txt @@ -1,3 +1,3 @@ pytest==6.2.2 google-cloud-storage>=1.26.0, <2.0.0dev -google-cloud-aiplatform==0.5.0 +google-cloud-aiplatform==0.5.1 From 116a29b1efcebb15bad14c3c36d3591c09ef10be Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Wed, 3 Mar 2021 13:22:34 -0800 Subject: [PATCH 4/7] fix: skip create data labeling job sample tests (#254) --- .../create_data_labeling_job_active_learning_sample_test.py | 1 + .../create_data_labeling_job_image_segmentation_sample_test.py | 1 + samples/snippets/create_data_labeling_job_images_sample_test.py | 1 + samples/snippets/create_data_labeling_job_sample_test.py | 1 + .../create_data_labeling_job_specialist_pool_sample_test.py | 1 + samples/snippets/create_data_labeling_job_video_sample_test.py | 1 + 6 files changed, 6 insertions(+) diff --git a/samples/snippets/create_data_labeling_job_active_learning_sample_test.py b/samples/snippets/create_data_labeling_job_active_learning_sample_test.py index 8ac753eb52..4ec5394535 100644 --- a/samples/snippets/create_data_labeling_job_active_learning_sample_test.py +++ b/samples/snippets/create_data_labeling_job_active_learning_sample_test.py @@ -39,6 +39,7 @@ def teardown(teardown_data_labeling_job): # Creating a data labeling job for images +@pytest.mark.skip(reason="Flaky job state.") def test_create_data_labeling_job_active_learning_sample(capsys, shared_state): create_data_labeling_job_active_learning_sample.create_data_labeling_job_active_learning_sample( diff --git a/samples/snippets/create_data_labeling_job_image_segmentation_sample_test.py b/samples/snippets/create_data_labeling_job_image_segmentation_sample_test.py index 79ec63f1e9..e5f365d234 100644 --- a/samples/snippets/create_data_labeling_job_image_segmentation_sample_test.py +++ b/samples/snippets/create_data_labeling_job_image_segmentation_sample_test.py @@ -40,6 +40,7 @@ def teardown(teardown_data_labeling_job): # Creating a data labeling job for images +@pytest.mark.skip(reason="Flaky job state.") def test_create_data_labeling_job_image_segmentation_sample(capsys, shared_state): dataset = f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}" diff --git a/samples/snippets/create_data_labeling_job_images_sample_test.py b/samples/snippets/create_data_labeling_job_images_sample_test.py index 07ecda5d14..026a0fbd58 100644 --- a/samples/snippets/create_data_labeling_job_images_sample_test.py +++ b/samples/snippets/create_data_labeling_job_images_sample_test.py @@ -37,6 +37,7 @@ def teardown(teardown_data_labeling_job): # Creating a data labeling job for images +@pytest.mark.skip(reason="Flaky job state.") def test_ucaip_generated_create_data_labeling_job_sample(capsys, shared_state): dataset_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}" diff --git a/samples/snippets/create_data_labeling_job_sample_test.py b/samples/snippets/create_data_labeling_job_sample_test.py index 5a7b714685..847e452c0a 100644 --- a/samples/snippets/create_data_labeling_job_sample_test.py +++ 
b/samples/snippets/create_data_labeling_job_sample_test.py @@ -38,6 +38,7 @@ def teardown(teardown_data_labeling_job): # Creating a data labeling job for images +@pytest.mark.skip(reason="Flaky job state.") def test_ucaip_generated_create_data_labeling_job_sample(capsys, shared_state): dataset_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}" diff --git a/samples/snippets/create_data_labeling_job_specialist_pool_sample_test.py b/samples/snippets/create_data_labeling_job_specialist_pool_sample_test.py index ae7a70cba4..0f0f882c8c 100644 --- a/samples/snippets/create_data_labeling_job_specialist_pool_sample_test.py +++ b/samples/snippets/create_data_labeling_job_specialist_pool_sample_test.py @@ -40,6 +40,7 @@ def teardown(teardown_data_labeling_job): # Creating a data labeling job for images +@pytest.mark.skip(reason="Flaky job state.") def test_create_data_labeling_job_specialist_pool_sample(capsys, shared_state): dataset = f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}" diff --git a/samples/snippets/create_data_labeling_job_video_sample_test.py b/samples/snippets/create_data_labeling_job_video_sample_test.py index 53813e4e42..5d952ec552 100644 --- a/samples/snippets/create_data_labeling_job_video_sample_test.py +++ b/samples/snippets/create_data_labeling_job_video_sample_test.py @@ -37,6 +37,7 @@ def teardown(teardown_data_labeling_job): # Creating a data labeling job for images +@pytest.mark.skip(reason="Flaky job state.") def test_ucaip_generated_create_data_labeling_job_sample(capsys, shared_state): dataset_name = f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}" From 47b7f3815c18dd6d364cb5baaeae65a5644ff166 Mon Sep 17 00:00:00 2001 From: WhiteSource Renovate Date: Fri, 19 Mar 2021 19:21:49 +0100 Subject: [PATCH 5/7] chore(deps): update precommit hook pycqa/flake8 to v3.9.0 (#261) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a9024b15d7..32302e4883 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,6 @@ repos: hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.0 hooks: - id: flake8 From e5c1b1a4909d701efeb27f29af43a95516c51475 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Mon, 22 Mar 2021 09:47:33 -0700 Subject: [PATCH 6/7] feat: add Vizier service (#266) * feat: add Vizier service * chore: blacken Co-authored-by: Bu Sun Kim --- .gitignore | 4 +- .kokoro/build.sh | 10 + .kokoro/test-samples-against-head.sh | 28 + .kokoro/test-samples-impl.sh | 102 + .kokoro/test-samples.sh | 96 +- .pre-commit-config.yaml | 2 +- CONTRIBUTING.rst | 22 +- MANIFEST.in | 4 +- docs/aiplatform_v1beta1/services.rst | 1 + docs/aiplatform_v1beta1/vizier_service.rst | 11 + .../definition_v1/types/__init__.py | 4 +- .../schema/predict/instance/__init__.py | 1 + .../types/video_action_recognition.py | 4 +- .../types/video_classification.py | 4 +- .../types/video_object_tracking.py | 4 +- .../v1beta1/schema/predict/params/__init__.py | 1 + .../schema/predict/prediction/__init__.py | 5 +- .../predict/prediction_v1beta1/__init__.py | 2 - .../prediction_v1beta1/types/__init__.py | 2 - .../types/image_segmentation.py | 8 +- .../schema/trainingjob/definition/__init__.py | 13 +- .../definition_v1beta1/__init__.py | 8 +- .../definition_v1beta1/types/__init__.py | 12 +- .../types/automl_image_segmentation.py | 1 + .../types/automl_video_classification.py | 1 + 
.../export_evaluated_data_items_config.py | 13 +- .../services/dataset_service/async_client.py | 32 +- .../services/dataset_service/client.py | 5 +- .../services/dataset_service/pagers.py | 11 +- .../dataset_service/transports/grpc.py | 3 +- .../services/endpoint_service/async_client.py | 32 +- .../services/endpoint_service/client.py | 10 +- .../services/endpoint_service/pagers.py | 11 +- .../endpoint_service/transports/grpc.py | 3 +- .../services/job_service/async_client.py | 32 +- .../services/job_service/pagers.py | 11 +- .../services/job_service/transports/grpc.py | 3 +- .../migration_service/async_client.py | 32 +- .../services/migration_service/client.py | 5 +- .../services/migration_service/pagers.py | 11 +- .../migration_service/transports/grpc.py | 3 +- .../services/model_service/async_client.py | 32 +- .../services/model_service/pagers.py | 11 +- .../services/model_service/transports/grpc.py | 3 +- .../services/pipeline_service/async_client.py | 32 +- .../services/pipeline_service/pagers.py | 11 +- .../pipeline_service/transports/grpc.py | 3 +- .../prediction_service/async_client.py | 32 +- .../services/prediction_service/client.py | 5 +- .../prediction_service/transports/grpc.py | 3 +- .../specialist_pool_service/async_client.py | 32 +- .../specialist_pool_service/pagers.py | 11 +- .../transports/grpc.py | 3 +- google/cloud/aiplatform_v1/types/__init__.py | 392 +- google/cloud/aiplatform_v1beta1/__init__.py | 52 +- .../services/dataset_service/async_client.py | 34 +- .../services/dataset_service/client.py | 7 +- .../services/dataset_service/pagers.py | 11 +- .../dataset_service/transports/grpc.py | 3 +- .../services/endpoint_service/async_client.py | 32 +- .../services/endpoint_service/client.py | 10 +- .../services/endpoint_service/pagers.py | 11 +- .../endpoint_service/transports/grpc.py | 3 +- .../services/job_service/async_client.py | 41 +- .../services/job_service/client.py | 9 - .../services/job_service/pagers.py | 11 +- .../services/job_service/transports/grpc.py | 3 +- .../migration_service/async_client.py | 32 +- .../services/migration_service/client.py | 27 +- .../services/migration_service/pagers.py | 11 +- .../migration_service/transports/grpc.py | 3 +- .../services/model_service/async_client.py | 35 +- .../services/model_service/client.py | 3 - .../services/model_service/pagers.py | 11 +- .../services/model_service/transports/grpc.py | 3 +- .../services/pipeline_service/async_client.py | 35 +- .../services/pipeline_service/client.py | 3 - .../services/pipeline_service/pagers.py | 11 +- .../pipeline_service/transports/grpc.py | 3 +- .../prediction_service/async_client.py | 32 +- .../services/prediction_service/client.py | 10 +- .../prediction_service/transports/grpc.py | 3 +- .../specialist_pool_service/async_client.py | 33 +- .../specialist_pool_service/client.py | 1 - .../specialist_pool_service/pagers.py | 11 +- .../transports/grpc.py | 3 +- .../services/vizier_service/__init__.py | 24 + .../services/vizier_service/async_client.py | 1261 +++++ .../services/vizier_service/client.py | 1478 ++++++ .../services/vizier_service/pagers.py | 286 ++ .../vizier_service/transports/__init__.py | 35 + .../vizier_service/transports/base.py | 315 ++ .../vizier_service/transports/grpc.py | 685 +++ .../vizier_service/transports/grpc_asyncio.py | 702 +++ .../aiplatform_v1beta1/types/__init__.py | 504 +- .../types/accelerator_type.py | 2 - .../aiplatform_v1beta1/types/annotation.py | 4 +- .../types/annotation_spec.py | 2 +- .../types/batch_prediction_job.py | 5 +- 
.../aiplatform_v1beta1/types/custom_job.py | 17 +- .../aiplatform_v1beta1/types/data_item.py | 2 +- .../types/data_labeling_job.py | 8 +- .../cloud/aiplatform_v1beta1/types/dataset.py | 2 +- .../types/dataset_service.py | 2 - .../aiplatform_v1beta1/types/endpoint.py | 13 +- .../types/endpoint_service.py | 2 +- .../types/explanation_metadata.py | 6 +- google/cloud/aiplatform_v1beta1/types/io.py | 8 +- .../aiplatform_v1beta1/types/job_service.py | 17 +- .../aiplatform_v1beta1/types/job_state.py | 1 + .../types/machine_resources.py | 87 +- .../types/migratable_resource.py | 9 +- .../types/migration_service.py | 7 +- .../cloud/aiplatform_v1beta1/types/model.py | 32 +- .../aiplatform_v1beta1/types/model_service.py | 3 - .../types/pipeline_service.py | 5 +- .../types/specialist_pool_service.py | 2 - .../cloud/aiplatform_v1beta1/types/study.py | 167 +- .../types/training_pipeline.py | 17 +- .../types/user_action_reference.py | 1 - .../types/vizier_service.py | 479 ++ noxfile.py | 23 +- renovate.json | 3 +- tests/unit/gapic/aiplatform_v1/__init__.py | 15 + .../aiplatform_v1/test_dataset_service.py | 174 +- .../aiplatform_v1/test_endpoint_service.py | 124 +- .../gapic/aiplatform_v1/test_job_service.py | 370 +- .../aiplatform_v1/test_migration_service.py | 48 +- .../gapic/aiplatform_v1/test_model_service.py | 182 +- .../aiplatform_v1/test_pipeline_service.py | 102 +- .../test_specialist_pool_service.py | 102 +- .../unit/gapic/aiplatform_v1beta1/__init__.py | 15 + .../test_dataset_service.py | 174 +- .../test_endpoint_service.py | 124 +- .../aiplatform_v1beta1/test_job_service.py | 370 +- .../test_migration_service.py | 76 +- .../aiplatform_v1beta1/test_model_service.py | 182 +- .../test_pipeline_service.py | 102 +- .../test_prediction_service.py | 2 +- .../test_specialist_pool_service.py | 102 +- .../aiplatform_v1beta1/test_vizier_service.py | 4228 +++++++++++++++++ 141 files changed, 13458 insertions(+), 884 deletions(-) create mode 100755 .kokoro/test-samples-against-head.sh create mode 100755 .kokoro/test-samples-impl.sh create mode 100644 docs/aiplatform_v1beta1/vizier_service.rst create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/client.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py create mode 100644 google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py create mode 100644 google/cloud/aiplatform_v1beta1/types/vizier_service.py create mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py diff --git a/.gitignore b/.gitignore index b9daa52f11..b4243ced74 100644 --- a/.gitignore +++ b/.gitignore @@ -50,8 +50,10 @@ docs.metadata # Virtual environment env/ + +# Test logs coverage.xml -sponge_log.xml +*sponge_log.xml # System test environment variables. 
system_tests/local_test_setup diff --git a/.kokoro/build.sh b/.kokoro/build.sh index ef4eb9c094..35e4a0f6ce 100755 --- a/.kokoro/build.sh +++ b/.kokoro/build.sh @@ -40,6 +40,16 @@ python3 -m pip uninstall --yes --quiet nox-automation python3 -m pip install --upgrade --quiet nox python3 -m nox --version +# If this is a continuous build, send the test log to the FlakyBot. +# See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot. +if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"continuous"* ]]; then + cleanup() { + chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot + $KOKORO_GFILE_DIR/linux_amd64/flakybot + } + trap cleanup EXIT HUP +fi + # If NOX_SESSION is set, it only runs the specified session, # otherwise run all the sessions. if [[ -n "${NOX_SESSION:-}" ]]; then diff --git a/.kokoro/test-samples-against-head.sh b/.kokoro/test-samples-against-head.sh new file mode 100755 index 0000000000..8f0597f90d --- /dev/null +++ b/.kokoro/test-samples-against-head.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# A customized test runner for samples. +# +# For periodic builds, you can specify this file for testing against head. + +# `-e` enables the script to automatically fail when a command fails +# `-o pipefail` sets the exit code to the rightmost comment to exit with a non-zero +set -eo pipefail +# Enables `**` to include files nested inside sub-folders +shopt -s globstar + +cd github/python-aiplatform + +exec .kokoro/test-samples-impl.sh diff --git a/.kokoro/test-samples-impl.sh b/.kokoro/test-samples-impl.sh new file mode 100755 index 0000000000..cf5de74c17 --- /dev/null +++ b/.kokoro/test-samples-impl.sh @@ -0,0 +1,102 @@ +#!/bin/bash +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# `-e` enables the script to automatically fail when a command fails +# `-o pipefail` sets the exit code to the rightmost comment to exit with a non-zero +set -eo pipefail +# Enables `**` to include files nested inside sub-folders +shopt -s globstar + +# Exit early if samples directory doesn't exist +if [ ! -d "./samples" ]; then + echo "No tests run. `./samples` not found" + exit 0 +fi + +# Disable buffering, so that the logs stream through. 
+export PYTHONUNBUFFERED=1 + +# Debug: show build environment +env | grep KOKORO + +# Install nox +python3.6 -m pip install --upgrade --quiet nox + +# Use secrets acessor service account to get secrets +if [[ -f "${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" ]]; then + gcloud auth activate-service-account \ + --key-file="${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" \ + --project="cloud-devrel-kokoro-resources" +fi + +# This script will create 3 files: +# - testing/test-env.sh +# - testing/service-account.json +# - testing/client-secrets.json +./scripts/decrypt-secrets.sh + +source ./testing/test-env.sh +export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/testing/service-account.json + +# For cloud-run session, we activate the service account for gcloud sdk. +gcloud auth activate-service-account \ + --key-file "${GOOGLE_APPLICATION_CREDENTIALS}" + +export GOOGLE_CLIENT_SECRETS=$(pwd)/testing/client-secrets.json + +echo -e "\n******************** TESTING PROJECTS ********************" + +# Switch to 'fail at end' to allow all tests to complete before exiting. +set +e +# Use RTN to return a non-zero value if the test fails. +RTN=0 +ROOT=$(pwd) +# Find all requirements.txt in the samples directory (may break on whitespace). +for file in samples/**/requirements.txt; do + cd "$ROOT" + # Navigate to the project folder. + file=$(dirname "$file") + cd "$file" + + echo "------------------------------------------------------------" + echo "- testing $file" + echo "------------------------------------------------------------" + + # Use nox to execute the tests for the project. + python3.6 -m nox -s "$RUN_TESTS_SESSION" + EXIT=$? + + # If this is a periodic build, send the test log to the FlakyBot. + # See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot. + if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then + chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot + $KOKORO_GFILE_DIR/linux_amd64/flakybot + fi + + if [[ $EXIT -ne 0 ]]; then + RTN=1 + echo -e "\n Testing failed: Nox returned a non-zero exit code. \n" + else + echo -e "\n Testing completed.\n" + fi + +done +cd "$ROOT" + +# Workaround for Kokoro permissions issue: delete secrets +rm testing/{test-env.sh,client-secrets.json,service-account.json} + +exit "$RTN" diff --git a/.kokoro/test-samples.sh b/.kokoro/test-samples.sh index 4c034fa7c7..6bb4d5c30b 100755 --- a/.kokoro/test-samples.sh +++ b/.kokoro/test-samples.sh @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +# The default test runner for samples. +# +# For periodic builds, we rewinds the repo to the latest release, and +# run test-samples-impl.sh. # `-e` enables the script to automatically fail when a command fails # `-o pipefail` sets the exit code to the rightmost comment to exit with a non-zero @@ -24,87 +28,19 @@ cd github/python-aiplatform # Run periodic samples tests at latest release if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then + # preserving the test runner implementation. + cp .kokoro/test-samples-impl.sh "${TMPDIR}/test-samples-impl.sh" + echo "--- IMPORTANT IMPORTANT IMPORTANT ---" + echo "Now we rewind the repo back to the latest release..." LATEST_RELEASE=$(git describe --abbrev=0 --tags) git checkout $LATEST_RELEASE -fi - -# Exit early if samples directory doesn't exist -if [ ! -d "./samples" ]; then - echo "No tests run. `./samples` not found" - exit 0 -fi - -# Disable buffering, so that the logs stream through. 
-export PYTHONUNBUFFERED=1 - -# Debug: show build environment -env | grep KOKORO - -# Install nox -python3.6 -m pip install --upgrade --quiet nox - -# Use secrets acessor service account to get secrets -if [[ -f "${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" ]]; then - gcloud auth activate-service-account \ - --key-file="${KOKORO_GFILE_DIR}/secrets_viewer_service_account.json" \ - --project="cloud-devrel-kokoro-resources" -fi - -# This script will create 3 files: -# - testing/test-env.sh -# - testing/service-account.json -# - testing/client-secrets.json -./scripts/decrypt-secrets.sh - -source ./testing/test-env.sh -export GOOGLE_APPLICATION_CREDENTIALS=$(pwd)/testing/service-account.json - -# For cloud-run session, we activate the service account for gcloud sdk. -gcloud auth activate-service-account \ - --key-file "${GOOGLE_APPLICATION_CREDENTIALS}" - -export GOOGLE_CLIENT_SECRETS=$(pwd)/testing/client-secrets.json - -echo -e "\n******************** TESTING PROJECTS ********************" - -# Switch to 'fail at end' to allow all tests to complete before exiting. -set +e -# Use RTN to return a non-zero value if the test fails. -RTN=0 -ROOT=$(pwd) -# Find all requirements.txt in the samples directory (may break on whitespace). -for file in samples/**/requirements.txt; do - cd "$ROOT" - # Navigate to the project folder. - file=$(dirname "$file") - cd "$file" - - echo "------------------------------------------------------------" - echo "- testing $file" - echo "------------------------------------------------------------" - - # Use nox to execute the tests for the project. - python3.6 -m nox -s "$RUN_TESTS_SESSION" - EXIT=$? - - # If this is a periodic build, send the test log to the FlakyBot. - # See https://github.com/googleapis/repo-automation-bots/tree/master/packages/flakybot. - if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then - chmod +x $KOKORO_GFILE_DIR/linux_amd64/flakybot - $KOKORO_GFILE_DIR/linux_amd64/flakybot + echo "The current head is: " + echo $(git rev-parse --verify HEAD) + echo "--- IMPORTANT IMPORTANT IMPORTANT ---" + # move back the test runner implementation if there's no file. + if [ ! -f .kokoro/test-samples-impl.sh ]; then + cp "${TMPDIR}/test-samples-impl.sh" .kokoro/test-samples-impl.sh fi +fi - if [[ $EXIT -ne 0 ]]; then - RTN=1 - echo -e "\n Testing failed: Nox returned a non-zero exit code. \n" - else - echo -e "\n Testing completed.\n" - fi - -done -cd "$ROOT" - -# Workaround for Kokoro permissions issue: delete secrets -rm testing/{test-env.sh,client-secrets.json,service-account.json} - -exit "$RTN" +exec .kokoro/test-samples-impl.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 32302e4883..a9024b15d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,6 @@ repos: hooks: - id: black - repo: https://gitlab.com/pycqa/flake8 - rev: 3.9.0 + rev: 3.8.4 hooks: - id: flake8 diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 3cab430ce1..66216c172d 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -70,9 +70,14 @@ We use `nox `__ to instrument our tests. - To test your changes, run unit tests with ``nox``:: $ nox -s unit-2.7 - $ nox -s unit-3.7 + $ nox -s unit-3.8 $ ... +- Args to pytest can be passed through the nox command separated by a `--`. For + example, to run a single test:: + + $ nox -s unit-3.8 -- -k + .. 
note:: The unit tests and system tests are described in the @@ -93,8 +98,12 @@ On Debian/Ubuntu:: ************ Coding Style ************ +- We use the automatic code formatter ``black``. You can run it using + the nox session ``blacken``. This will eliminate many lint errors. Run via:: + + $ nox -s blacken -- PEP8 compliance, with exceptions defined in the linter configuration. +- PEP8 compliance is required, with exceptions defined in the linter configuration. If you have ``nox`` installed, you can test that you have not introduced any non-compliant code via:: @@ -133,13 +142,18 @@ Running System Tests - To run system tests, you can execute:: - $ nox -s system-3.7 + # Run all system tests + $ nox -s system-3.8 $ nox -s system-2.7 + # Run a single system test + $ nox -s system-3.8 -- -k + + .. note:: System tests are only configured to run under Python 2.7 and - Python 3.7. For expediency, we do not run them in older versions + Python 3.8. For expediency, we do not run them in older versions of Python 3. This alone will not run the tests. You'll need to change some local diff --git a/MANIFEST.in b/MANIFEST.in index e9e29d1203..e783f4c620 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -16,10 +16,10 @@ # Generated by synthtool. DO NOT EDIT! include README.rst LICENSE -recursive-include google *.json *.proto +recursive-include google *.json *.proto py.typed recursive-include tests * global-exclude *.py[co] global-exclude __pycache__ # Exclude scripts for samples readmegen -prune scripts/readme-gen \ No newline at end of file +prune scripts/readme-gen diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index dd8c8a41bc..6e4f84c707 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -11,3 +11,4 @@ Services for Google Cloud Aiplatform v1beta1 API pipeline_service prediction_service specialist_pool_service + vizier_service diff --git a/docs/aiplatform_v1beta1/vizier_service.rst b/docs/aiplatform_v1beta1/vizier_service.rst new file mode 100644 index 0000000000..7235400038 --- /dev/null +++ b/docs/aiplatform_v1beta1/vizier_service.rst @@ -0,0 +1,11 @@ +VizierService +------------------------------- + +.. automodule:: google.cloud.aiplatform_v1beta1.services.vizier_service + :members: + :inherited-members: + + +.. 
automodule:: google.cloud.aiplatform_v1beta1.services.vizier_service.pagers + :members: + :inherited-members: diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py index 8a60b2e36c..a15aa2c041 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py @@ -30,7 +30,6 @@ AutoMlImageSegmentationInputs, AutoMlImageSegmentationMetadata, ) -from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig from .automl_tables import ( AutoMlTables, AutoMlTablesInputs, @@ -60,6 +59,7 @@ AutoMlVideoObjectTracking, AutoMlVideoObjectTrackingInputs, ) +from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig __all__ = ( "AutoMlImageClassification", @@ -71,7 +71,6 @@ "AutoMlImageSegmentation", "AutoMlImageSegmentationInputs", "AutoMlImageSegmentationMetadata", - "ExportEvaluatedDataItemsConfig", "AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata", @@ -87,4 +86,5 @@ "AutoMlVideoClassificationInputs", "AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py index 32ad0e9ff2..2f514ac4ed 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types.image_classification import ( ImageClassificationPredictionInstance, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py index 89be6318f8..ae3935d387 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_action_recognition.py @@ -48,8 +48,8 @@ class VideoActionRecognitionPredictionInstance(proto.Message): Expressed as a number of seconds as measured from the start of the video, with "s" appended at the end. Fractions are allowed, up to a - microsecond precision, and "Infinity" is - allowed, which means the end of the video. + microsecond precision, and "inf" or "Infinity" + is allowed, which means the end of the video. """ content = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py index 41ab3bc217..2f944bb99e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_classification.py @@ -48,8 +48,8 @@ class VideoClassificationPredictionInstance(proto.Message): Expressed as a number of seconds as measured from the start of the video, with "s" appended at the end. Fractions are allowed, up to a - microsecond precision, and "Infinity" is - allowed, which means the end of the video. 
+ microsecond precision, and "inf" or "Infinity" + is allowed, which means the end of the video. """ content = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py index 3729c14816..e635b5174b 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/video_object_tracking.py @@ -48,8 +48,8 @@ class VideoObjectTrackingPredictionInstance(proto.Message): Expressed as a number of seconds as measured from the start of the video, with "s" appended at the end. Fractions are allowed, up to a - microsecond precision, and "Infinity" is - allowed, which means the end of the video. + microsecond precision, and "inf" or "Infinity" + is allowed, which means the end of the video. """ content = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py index 4a410f3904..dc7cd58e9a 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # + from google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types.image_classification import ( ImageClassificationPredictionParams, ) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py index 159824217b..d5f2762504 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.classification import ( ClassificationPredictionResult, ) @@ -35,9 +36,6 @@ from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.text_sentiment import ( TextSentimentPredictionResult, ) -from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.time_series_forecasting import ( - TimeSeriesForecastingPredictionResult, -) from google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.video_action_recognition import ( VideoActionRecognitionPredictionResult, ) @@ -56,7 +54,6 @@ "TabularRegressionPredictionResult", "TextExtractionPredictionResult", "TextSentimentPredictionResult", - "TimeSeriesForecastingPredictionResult", "VideoActionRecognitionPredictionResult", "VideoClassificationPredictionResult", "VideoObjectTrackingPredictionResult", diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py index 37066cd8b3..91fae5a3b1 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/__init__.py @@ -22,7 +22,6 @@ from .types.tabular_regression import TabularRegressionPredictionResult from .types.text_extraction import TextExtractionPredictionResult from .types.text_sentiment import TextSentimentPredictionResult -from .types.time_series_forecasting import TimeSeriesForecastingPredictionResult from .types.video_action_recognition import VideoActionRecognitionPredictionResult from .types.video_classification import VideoClassificationPredictionResult from .types.video_object_tracking import VideoObjectTrackingPredictionResult @@ -35,7 +34,6 @@ "TabularRegressionPredictionResult", "TextExtractionPredictionResult", "TextSentimentPredictionResult", - "TimeSeriesForecastingPredictionResult", "VideoActionRecognitionPredictionResult", "VideoClassificationPredictionResult", "VideoObjectTrackingPredictionResult", diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py index 5ec1ed095e..a0fd2058e0 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py @@ -22,7 +22,6 @@ from .tabular_regression import TabularRegressionPredictionResult from .text_extraction import TextExtractionPredictionResult from .text_sentiment import TextSentimentPredictionResult -from .time_series_forecasting import TimeSeriesForecastingPredictionResult from .video_action_recognition import VideoActionRecognitionPredictionResult from .video_classification import VideoClassificationPredictionResult from .video_object_tracking import VideoObjectTrackingPredictionResult @@ -35,7 +34,6 @@ "TabularRegressionPredictionResult", "TextExtractionPredictionResult", "TextSentimentPredictionResult", - "TimeSeriesForecastingPredictionResult", "VideoActionRecognitionPredictionResult", "VideoClassificationPredictionResult", "VideoObjectTrackingPredictionResult", diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py index 195dea6f79..ffd6fb9380 100644 --- 
a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_segmentation.py @@ -28,7 +28,7 @@ class ImageSegmentationPredictionResult(proto.Message): r"""Prediction output format for Image Segmentation. Attributes: - category_mask (bytes): + category_mask (str): A PNG image where each pixel in the mask represents the category in which the pixel in the original image was predicted to belong to. @@ -39,7 +39,7 @@ class ImageSegmentationPredictionResult(proto.Message): likely category and if none of the categories reach the confidence threshold, the pixel will be marked as background. - confidence_mask (bytes): + confidence_mask (str): A one channel image which is encoded as an 8bit lossless PNG. The size of the image will be the same as the original image. For a specific @@ -49,9 +49,9 @@ class ImageSegmentationPredictionResult(proto.Message): confidence and white means complete confidence. """ - category_mask = proto.Field(proto.BYTES, number=1) + category_mask = proto.Field(proto.STRING, number=1) - confidence_mask = proto.Field(proto.BYTES, number=2) + confidence_mask = proto.Field(proto.STRING, number=2) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py index 392fae649e..d632ef9609 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/__init__.py @@ -14,15 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_forecasting import ( - AutoMlForecasting, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_forecasting import ( - AutoMlForecastingInputs, -) -from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_forecasting import ( - AutoMlForecastingMetadata, -) + from google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.automl_image_classification import ( AutoMlImageClassification, ) @@ -100,9 +92,6 @@ ) __all__ = ( - "AutoMlForecasting", - "AutoMlForecastingInputs", - "AutoMlForecastingMetadata", "AutoMlImageClassification", "AutoMlImageClassificationInputs", "AutoMlImageClassificationMetadata", diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py index 346ea62686..34958e5add 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/__init__.py @@ -15,9 +15,6 @@ # limitations under the License. 
# -from .types.automl_forecasting import AutoMlForecasting -from .types.automl_forecasting import AutoMlForecastingInputs -from .types.automl_forecasting import AutoMlForecastingMetadata from .types.automl_image_classification import AutoMlImageClassification from .types.automl_image_classification import AutoMlImageClassificationInputs from .types.automl_image_classification import AutoMlImageClassificationMetadata @@ -46,10 +43,6 @@ __all__ = ( - "AutoMlForecasting", - "AutoMlForecastingInputs", - "AutoMlForecastingMetadata", - "AutoMlImageClassification", "AutoMlImageClassificationInputs", "AutoMlImageClassificationMetadata", "AutoMlImageObjectDetection", @@ -74,4 +67,5 @@ "AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs", "ExportEvaluatedDataItemsConfig", + "AutoMlImageClassification", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py index 3853ca87a9..a15aa2c041 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py @@ -15,12 +15,6 @@ # limitations under the License. # -from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig -from .automl_forecasting import ( - AutoMlForecasting, - AutoMlForecastingInputs, - AutoMlForecastingMetadata, -) from .automl_image_classification import ( AutoMlImageClassification, AutoMlImageClassificationInputs, @@ -65,12 +59,9 @@ AutoMlVideoObjectTracking, AutoMlVideoObjectTrackingInputs, ) +from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig __all__ = ( - "ExportEvaluatedDataItemsConfig", - "AutoMlForecasting", - "AutoMlForecastingInputs", - "AutoMlForecastingMetadata", "AutoMlImageClassification", "AutoMlImageClassificationInputs", "AutoMlImageClassificationMetadata", @@ -95,4 +86,5 @@ "AutoMlVideoClassificationInputs", "AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", ) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py index 22c199e7f5..014df43b2f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py @@ -82,6 +82,7 @@ class ModelType(proto.Enum): MODEL_TYPE_UNSPECIFIED = 0 CLOUD_HIGH_ACCURACY_1 = 1 CLOUD_LOW_ACCURACY_1 = 2 + MOBILE_TF_LOW_LATENCY_1 = 3 model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py index 51195eb327..e1c12eb46c 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py @@ -51,6 +51,7 @@ class ModelType(proto.Enum): MODEL_TYPE_UNSPECIFIED = 0 CLOUD = 1 MOBILE_VERSATILE_1 = 2 + MOBILE_JETSON_VERSATILE_1 = 3 model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) diff --git 
a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py index 29bc547adf..9a6195fec2 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/export_evaluated_data_items_config.py @@ -30,18 +30,19 @@ class ExportEvaluatedDataItemsConfig(proto.Message): Attributes: destination_bigquery_uri (str): - URI of desired destination BigQuery table. If not specified, - then results are exported to the following auto-created - BigQuery table: + URI of desired destination BigQuery table. Expected format: + bq://:: + + If not specified, then results are exported to the following + auto-created BigQuery table: :export_evaluated_examples__.evaluated_examples override_existing_table (bool): If true and an export destination is specified, then the contents of the destination - will be overwritten. Otherwise, if the export + are overwritten. Otherwise, if the export destination already exists, then the export - operation will not trigger and a failure - response is returned. + operation fails. """ destination_bigquery_uri = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py index d5b56b54f5..a07ee32dfd 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -97,8 +97,36 @@ class DatasetServiceAsyncClient: DatasetServiceClient.parse_common_location_path ) - from_service_account_info = DatasetServiceClient.from_service_account_info - from_service_account_file = DatasetServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceAsyncClient: The constructed client. + """ + return DatasetServiceClient.from_service_account_info.__func__(DatasetServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceAsyncClient: The constructed client. 
+ """ + return DatasetServiceClient.from_service_account_file.__func__(DatasetServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py index e545dbe56e..160a2049b8 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -926,9 +926,8 @@ def import_data( if name is not None: request.name = name - - if import_configs: - request.import_configs.extend(import_configs) + if import_configs is not None: + request.import_configs = import_configs # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py index f195ca3308..c3f8265b6e 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import annotation from google.cloud.aiplatform_v1.types import data_item diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py index e5a54388cb..20a01deb79 100644 --- a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -248,8 +248,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py index e36aa6dfde..13f099328b 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -87,8 +87,36 @@ class EndpointServiceAsyncClient: EndpointServiceClient.parse_common_location_path ) - from_service_account_info = EndpointServiceClient.from_service_account_info - from_service_account_file = EndpointServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceAsyncClient: The constructed client. + """ + return EndpointServiceClient.from_service_account_info.__func__(EndpointServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. 
+ args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceAsyncClient: The constructed client. + """ + return EndpointServiceClient.from_service_account_file.__func__(EndpointServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py index 1316effa58..de54b0b9b5 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -903,9 +903,8 @@ def deploy_model( request.endpoint = endpoint if deployed_model is not None: request.deployed_model = deployed_model - - if traffic_split: - request.traffic_split.update(traffic_split) + if traffic_split is not None: + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -1022,9 +1021,8 @@ def undeploy_model( request.endpoint = endpoint if deployed_model_id is not None: request.deployed_model_id = deployed_model_id - - if traffic_split: - request.traffic_split.update(traffic_split) + if traffic_split is not None: + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py index 01ebccdec3..c22df91c8c 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import endpoint from google.cloud.aiplatform_v1.types import endpoint_service diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py index f0b8b32de1..d2c13c3fe7 100644 --- a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -247,8 +247,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py index 689cb131ea..e253bcc5d6 100644 --- a/google/cloud/aiplatform_v1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -116,8 +116,36 @@ class JobServiceAsyncClient: JobServiceClient.parse_common_location_path ) - from_service_account_info = JobServiceClient.from_service_account_info - from_service_account_file = JobServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. 
+ args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + JobServiceAsyncClient: The constructed client. + """ + return JobServiceClient.from_service_account_info.__func__(JobServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + JobServiceAsyncClient: The constructed client. + """ + return JobServiceClient.from_service_account_file.__func__(JobServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/job_service/pagers.py b/google/cloud/aiplatform_v1/services/job_service/pagers.py index b5a0f4b929..35d679b6ad 100644 --- a/google/cloud/aiplatform_v1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/job_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import batch_prediction_job from google.cloud.aiplatform_v1.types import custom_job diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py index 139aaf3345..a9c90ecdaa 100644 --- a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -260,8 +260,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py index fcb1d23da7..e7f45eeaf5 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -96,8 +96,36 @@ class MigrationServiceAsyncClient: MigrationServiceClient.parse_common_location_path ) - from_service_account_info = MigrationServiceClient.from_service_account_info - from_service_account_file = MigrationServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceAsyncClient: The constructed client. + """ + return MigrationServiceClient.from_service_account_info.__func__(MigrationServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. 
+ + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceAsyncClient: The constructed client. + """ + return MigrationServiceClient.from_service_account_file.__func__(MigrationServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 3ed18e0fa8..0a23f262c2 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -612,9 +612,8 @@ def batch_migrate_resources( if parent is not None: request.parent = parent - - if migrate_resource_requests: - request.migrate_resource_requests.extend(migrate_resource_requests) + if migrate_resource_requests is not None: + request.migrate_resource_requests = migrate_resource_requests # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1/services/migration_service/pagers.py index b7d9f4ae44..02a46451df 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/migration_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import migratable_resource from google.cloud.aiplatform_v1.types import migration_service diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py index 820a38a028..f11d72386d 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -249,8 +249,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py index 123b922019..687c22455a 100644 --- a/google/cloud/aiplatform_v1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -101,8 +101,36 @@ class ModelServiceAsyncClient: ModelServiceClient.parse_common_location_path ) - from_service_account_info = ModelServiceClient.from_service_account_info - from_service_account_file = ModelServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. 
+ """ + return ModelServiceClient.from_service_account_info.__func__(ModelServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_file.__func__(ModelServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/model_service/pagers.py b/google/cloud/aiplatform_v1/services/model_service/pagers.py index be652f745f..d01f0057c1 100644 --- a/google/cloud/aiplatform_v1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/model_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import model from google.cloud.aiplatform_v1.types import model_evaluation diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py index 90dcfd008d..b6f2efb427 100644 --- a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ -251,8 +251,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py index 95c7d8a176..fc7337a7a3 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -94,8 +94,36 @@ class PipelineServiceAsyncClient: PipelineServiceClient.parse_common_location_path ) - from_service_account_info = PipelineServiceClient.from_service_account_info - from_service_account_file = PipelineServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PipelineServiceAsyncClient: The constructed client. + """ + return PipelineServiceClient.from_service_account_info.__func__(PipelineServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. 
+ + Returns: + PipelineServiceAsyncClient: The constructed client. + """ + return PipelineServiceClient.from_service_account_file.__func__(PipelineServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py index 0f3503ff5a..987c37dba2 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import pipeline_service from google.cloud.aiplatform_v1.types import training_pipeline diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py index 818144f008..b7d20db080 100644 --- a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -250,8 +250,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py index c0ab09622c..cc6d011e88 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -76,8 +76,36 @@ class PredictionServiceAsyncClient: PredictionServiceClient.parse_common_location_path ) - from_service_account_info = PredictionServiceClient.from_service_account_info - from_service_account_file = PredictionServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceAsyncClient: The constructed client. + """ + return PredictionServiceClient.from_service_account_info.__func__(PredictionServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceAsyncClient: The constructed client. 
+ """ + return PredictionServiceClient.from_service_account_file.__func__(PredictionServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py index 55c52b48f4..029fb851b8 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -432,12 +432,11 @@ def predict( if endpoint is not None: request.endpoint = endpoint + if instances is not None: + request.instances.extend(instances) if parameters is not None: request.parameters = parameters - if instances: - request.instances.extend(instances) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.predict] diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py index 4fcfe5b442..86aef5e81a 100644 --- a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -244,8 +244,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py index 496f6aa319..57e2b8a0a7 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -95,8 +95,36 @@ class SpecialistPoolServiceAsyncClient: SpecialistPoolServiceClient.parse_common_location_path ) - from_service_account_info = SpecialistPoolServiceClient.from_service_account_info - from_service_account_file = SpecialistPoolServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceAsyncClient: The constructed client. + """ + return SpecialistPoolServiceClient.from_service_account_info.__func__(SpecialistPoolServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceAsyncClient: The constructed client. 
+ """ + return SpecialistPoolServiceClient.from_service_account_file.__func__(SpecialistPoolServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py index b55e53169e..e64a827049 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1.types import specialist_pool from google.cloud.aiplatform_v1.types import specialist_pool_service diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py index c9895648d2..cb8904bc07 100644 --- a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -253,8 +253,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index f073d451fe..6d7c9ca42f 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -15,347 +15,347 @@ # limitations under the License. 
# -from .user_action_reference import UserActionReference from .annotation import Annotation from .annotation_spec import AnnotationSpec -from .completion_stats import CompletionStats -from .encryption_spec import EncryptionSpec -from .io import ( - GcsSource, - GcsDestination, - BigQuerySource, - BigQueryDestination, - ContainerRegistryDestination, -) -from .machine_resources import ( - MachineSpec, - DedicatedResources, - AutomaticResources, - BatchDedicatedResources, - ResourcesConsumed, - DiskSpec, -) -from .manual_batch_tuning_parameters import ManualBatchTuningParameters from .batch_prediction_job import BatchPredictionJob -from .env_var import EnvVar +from .completion_stats import CompletionStats from .custom_job import ( + ContainerSpec, CustomJob, CustomJobSpec, - WorkerPoolSpec, - ContainerSpec, PythonPackageSpec, Scheduling, + WorkerPoolSpec, ) from .data_item import DataItem -from .specialist_pool import SpecialistPool from .data_labeling_job import ( - DataLabelingJob, ActiveLearningConfig, + DataLabelingJob, SampleConfig, TrainingConfig, ) from .dataset import ( Dataset, - ImportDataConfig, ExportDataConfig, -) -from .operation import ( - GenericOperationMetadata, - DeleteOperationMetadata, -) -from .deployed_model_ref import DeployedModelRef -from .model import ( - Model, - PredictSchemata, - ModelContainerSpec, - Port, -) -from .training_pipeline import ( - TrainingPipeline, - InputDataConfig, - FractionSplit, - FilterSplit, - PredefinedSplit, - TimestampSplit, + ImportDataConfig, ) from .dataset_service import ( - CreateDatasetRequest, CreateDatasetOperationMetadata, - GetDatasetRequest, - UpdateDatasetRequest, - ListDatasetsRequest, - ListDatasetsResponse, + CreateDatasetRequest, DeleteDatasetRequest, - ImportDataRequest, - ImportDataResponse, - ImportDataOperationMetadata, + ExportDataOperationMetadata, ExportDataRequest, ExportDataResponse, - ExportDataOperationMetadata, - ListDataItemsRequest, - ListDataItemsResponse, GetAnnotationSpecRequest, + GetDatasetRequest, + ImportDataOperationMetadata, + ImportDataRequest, + ImportDataResponse, ListAnnotationsRequest, ListAnnotationsResponse, + ListDataItemsRequest, + ListDataItemsResponse, + ListDatasetsRequest, + ListDatasetsResponse, + UpdateDatasetRequest, ) +from .deployed_model_ref import DeployedModelRef +from .encryption_spec import EncryptionSpec from .endpoint import ( - Endpoint, DeployedModel, + Endpoint, ) from .endpoint_service import ( - CreateEndpointRequest, CreateEndpointOperationMetadata, - GetEndpointRequest, - ListEndpointsRequest, - ListEndpointsResponse, - UpdateEndpointRequest, + CreateEndpointRequest, DeleteEndpointRequest, + DeployModelOperationMetadata, DeployModelRequest, DeployModelResponse, - DeployModelOperationMetadata, + GetEndpointRequest, + ListEndpointsRequest, + ListEndpointsResponse, + UndeployModelOperationMetadata, UndeployModelRequest, UndeployModelResponse, - UndeployModelOperationMetadata, -) -from .study import ( - Trial, - StudySpec, - Measurement, + UpdateEndpointRequest, ) +from .env_var import EnvVar from .hyperparameter_tuning_job import HyperparameterTuningJob +from .io import ( + BigQueryDestination, + BigQuerySource, + ContainerRegistryDestination, + GcsDestination, + GcsSource, +) from .job_service import ( + CancelBatchPredictionJobRequest, + CancelCustomJobRequest, + CancelDataLabelingJobRequest, + CancelHyperparameterTuningJobRequest, + CreateBatchPredictionJobRequest, CreateCustomJobRequest, + CreateDataLabelingJobRequest, + CreateHyperparameterTuningJobRequest, + 
DeleteBatchPredictionJobRequest, + DeleteCustomJobRequest, + DeleteDataLabelingJobRequest, + DeleteHyperparameterTuningJobRequest, + GetBatchPredictionJobRequest, GetCustomJobRequest, + GetDataLabelingJobRequest, + GetHyperparameterTuningJobRequest, + ListBatchPredictionJobsRequest, + ListBatchPredictionJobsResponse, ListCustomJobsRequest, ListCustomJobsResponse, - DeleteCustomJobRequest, - CancelCustomJobRequest, - CreateDataLabelingJobRequest, - GetDataLabelingJobRequest, ListDataLabelingJobsRequest, ListDataLabelingJobsResponse, - DeleteDataLabelingJobRequest, - CancelDataLabelingJobRequest, - CreateHyperparameterTuningJobRequest, - GetHyperparameterTuningJobRequest, ListHyperparameterTuningJobsRequest, ListHyperparameterTuningJobsResponse, - DeleteHyperparameterTuningJobRequest, - CancelHyperparameterTuningJobRequest, - CreateBatchPredictionJobRequest, - GetBatchPredictionJobRequest, - ListBatchPredictionJobsRequest, - ListBatchPredictionJobsResponse, - DeleteBatchPredictionJobRequest, - CancelBatchPredictionJobRequest, ) +from .machine_resources import ( + AutomaticResources, + BatchDedicatedResources, + DedicatedResources, + DiskSpec, + MachineSpec, + ResourcesConsumed, +) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters from .migratable_resource import MigratableResource from .migration_service import ( - SearchMigratableResourcesRequest, - SearchMigratableResourcesResponse, + BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, - MigrateResourceRequest, BatchMigrateResourcesResponse, + MigrateResourceRequest, MigrateResourceResponse, - BatchMigrateResourcesOperationMetadata, + SearchMigratableResourcesRequest, + SearchMigratableResourcesResponse, +) +from .model import ( + Model, + ModelContainerSpec, + Port, + PredictSchemata, ) from .model_evaluation import ModelEvaluation from .model_evaluation_slice import ModelEvaluationSlice from .model_service import ( - UploadModelRequest, - UploadModelOperationMetadata, - UploadModelResponse, - GetModelRequest, - ListModelsRequest, - ListModelsResponse, - UpdateModelRequest, DeleteModelRequest, - ExportModelRequest, ExportModelOperationMetadata, + ExportModelRequest, ExportModelResponse, GetModelEvaluationRequest, - ListModelEvaluationsRequest, - ListModelEvaluationsResponse, GetModelEvaluationSliceRequest, + GetModelRequest, ListModelEvaluationSlicesRequest, ListModelEvaluationSlicesResponse, + ListModelEvaluationsRequest, + ListModelEvaluationsResponse, + ListModelsRequest, + ListModelsResponse, + UpdateModelRequest, + UploadModelOperationMetadata, + UploadModelRequest, + UploadModelResponse, +) +from .operation import ( + DeleteOperationMetadata, + GenericOperationMetadata, ) from .pipeline_service import ( + CancelTrainingPipelineRequest, CreateTrainingPipelineRequest, + DeleteTrainingPipelineRequest, GetTrainingPipelineRequest, ListTrainingPipelinesRequest, ListTrainingPipelinesResponse, - DeleteTrainingPipelineRequest, - CancelTrainingPipelineRequest, ) from .prediction_service import ( PredictRequest, PredictResponse, ) +from .specialist_pool import SpecialistPool from .specialist_pool_service import ( - CreateSpecialistPoolRequest, CreateSpecialistPoolOperationMetadata, + CreateSpecialistPoolRequest, + DeleteSpecialistPoolRequest, GetSpecialistPoolRequest, ListSpecialistPoolsRequest, ListSpecialistPoolsResponse, - DeleteSpecialistPoolRequest, - UpdateSpecialistPoolRequest, UpdateSpecialistPoolOperationMetadata, + UpdateSpecialistPoolRequest, +) +from .study import ( + Measurement, + 
StudySpec, + Trial, ) +from .training_pipeline import ( + FilterSplit, + FractionSplit, + InputDataConfig, + PredefinedSplit, + TimestampSplit, + TrainingPipeline, +) +from .user_action_reference import UserActionReference __all__ = ( "AcceleratorType", - "UserActionReference", "Annotation", "AnnotationSpec", - "CompletionStats", - "EncryptionSpec", - "GcsSource", - "GcsDestination", - "BigQuerySource", - "BigQueryDestination", - "ContainerRegistryDestination", - "JobState", - "MachineSpec", - "DedicatedResources", - "AutomaticResources", - "BatchDedicatedResources", - "ResourcesConsumed", - "DiskSpec", - "ManualBatchTuningParameters", "BatchPredictionJob", - "EnvVar", + "CompletionStats", + "ContainerSpec", "CustomJob", "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", "PythonPackageSpec", "Scheduling", + "WorkerPoolSpec", "DataItem", - "SpecialistPool", - "DataLabelingJob", "ActiveLearningConfig", + "DataLabelingJob", "SampleConfig", "TrainingConfig", "Dataset", - "ImportDataConfig", "ExportDataConfig", - "GenericOperationMetadata", - "DeleteOperationMetadata", - "DeployedModelRef", - "Model", - "PredictSchemata", - "ModelContainerSpec", - "Port", - "PipelineState", - "TrainingPipeline", - "InputDataConfig", - "FractionSplit", - "FilterSplit", - "PredefinedSplit", - "TimestampSplit", - "CreateDatasetRequest", + "ImportDataConfig", "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", + "CreateDatasetRequest", "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", + "ExportDataOperationMetadata", "ExportDataRequest", "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", "GetAnnotationSpecRequest", + "GetDatasetRequest", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", "ListAnnotationsRequest", "ListAnnotationsResponse", - "Endpoint", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "UpdateDatasetRequest", + "DeployedModelRef", + "EncryptionSpec", "DeployedModel", - "CreateEndpointRequest", + "Endpoint", "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", + "CreateEndpointRequest", "DeleteEndpointRequest", + "DeployModelOperationMetadata", "DeployModelRequest", "DeployModelResponse", - "DeployModelOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UndeployModelOperationMetadata", "UndeployModelRequest", "UndeployModelResponse", - "UndeployModelOperationMetadata", - "Trial", - "StudySpec", - "Measurement", + "UpdateEndpointRequest", + "EnvVar", "HyperparameterTuningJob", + "BigQueryDestination", + "BigQuerySource", + "ContainerRegistryDestination", + "GcsDestination", + "GcsSource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteHyperparameterTuningJobRequest", + "GetBatchPredictionJobRequest", "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetHyperparameterTuningJobRequest", + "ListBatchPredictionJobsRequest", + 
"ListBatchPredictionJobsResponse", "ListCustomJobsRequest", "ListCustomJobsResponse", - "DeleteCustomJobRequest", - "CancelCustomJobRequest", - "CreateDataLabelingJobRequest", - "GetDataLabelingJobRequest", "ListDataLabelingJobsRequest", "ListDataLabelingJobsResponse", - "DeleteDataLabelingJobRequest", - "CancelDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "GetHyperparameterTuningJobRequest", "ListHyperparameterTuningJobsRequest", "ListHyperparameterTuningJobsResponse", - "DeleteHyperparameterTuningJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "GetBatchPredictionJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "DeleteBatchPredictionJobRequest", - "CancelBatchPredictionJobRequest", + "JobState", + "AutomaticResources", + "BatchDedicatedResources", + "DedicatedResources", + "DiskSpec", + "MachineSpec", + "ResourcesConsumed", + "ManualBatchTuningParameters", "MigratableResource", - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", + "BatchMigrateResourcesOperationMetadata", "BatchMigrateResourcesRequest", - "MigrateResourceRequest", "BatchMigrateResourcesResponse", + "MigrateResourceRequest", "MigrateResourceResponse", - "BatchMigrateResourcesOperationMetadata", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "Model", + "ModelContainerSpec", + "Port", + "PredictSchemata", "ModelEvaluation", "ModelEvaluationSlice", - "UploadModelRequest", - "UploadModelOperationMetadata", - "UploadModelResponse", - "GetModelRequest", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", "DeleteModelRequest", - "ExportModelRequest", "ExportModelOperationMetadata", + "ExportModelRequest", "ExportModelResponse", "GetModelEvaluationRequest", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", "GetModelEvaluationSliceRequest", + "GetModelRequest", "ListModelEvaluationSlicesRequest", "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "DeleteOperationMetadata", + "GenericOperationMetadata", + "CancelTrainingPipelineRequest", "CreateTrainingPipelineRequest", + "DeleteTrainingPipelineRequest", "GetTrainingPipelineRequest", "ListTrainingPipelinesRequest", "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", + "PipelineState", "PredictRequest", "PredictResponse", - "CreateSpecialistPoolRequest", + "SpecialistPool", "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "DeleteSpecialistPoolRequest", "GetSpecialistPoolRequest", "ListSpecialistPoolsRequest", "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "Measurement", + "StudySpec", + "Trial", + "FilterSplit", + "FractionSplit", + "InputDataConfig", + "PredefinedSplit", + "TimestampSplit", + "TrainingPipeline", + "UserActionReference", ) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index b76824eac3..621f1e96f8 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -23,6 +23,7 @@ from .services.pipeline_service import PipelineServiceClient from 
.services.prediction_service import PredictionServiceClient from .services.specialist_pool_service import SpecialistPoolServiceClient +from .services.vizier_service import VizierServiceClient from .types.accelerator_type import AcceleratorType from .types.annotation import Annotation from .types.annotation_spec import AnnotationSpec @@ -123,6 +124,7 @@ from .types.job_service import ListHyperparameterTuningJobsResponse from .types.job_state import JobState from .types.machine_resources import AutomaticResources +from .types.machine_resources import AutoscalingMetricSpec from .types.machine_resources import BatchDedicatedResources from .types.machine_resources import DedicatedResources from .types.machine_resources import DiskSpec @@ -183,6 +185,7 @@ from .types.specialist_pool_service import UpdateSpecialistPoolOperationMetadata from .types.specialist_pool_service import UpdateSpecialistPoolRequest from .types.study import Measurement +from .types.study import Study from .types.study import StudySpec from .types.study import Trial from .types.training_pipeline import FilterSplit @@ -192,15 +195,39 @@ from .types.training_pipeline import TimestampSplit from .types.training_pipeline import TrainingPipeline from .types.user_action_reference import UserActionReference +from .types.vizier_service import AddTrialMeasurementRequest +from .types.vizier_service import CheckTrialEarlyStoppingStateMetatdata +from .types.vizier_service import CheckTrialEarlyStoppingStateRequest +from .types.vizier_service import CheckTrialEarlyStoppingStateResponse +from .types.vizier_service import CompleteTrialRequest +from .types.vizier_service import CreateStudyRequest +from .types.vizier_service import CreateTrialRequest +from .types.vizier_service import DeleteStudyRequest +from .types.vizier_service import DeleteTrialRequest +from .types.vizier_service import GetStudyRequest +from .types.vizier_service import GetTrialRequest +from .types.vizier_service import ListOptimalTrialsRequest +from .types.vizier_service import ListOptimalTrialsResponse +from .types.vizier_service import ListStudiesRequest +from .types.vizier_service import ListStudiesResponse +from .types.vizier_service import ListTrialsRequest +from .types.vizier_service import ListTrialsResponse +from .types.vizier_service import LookupStudyRequest +from .types.vizier_service import StopTrialRequest +from .types.vizier_service import SuggestTrialsMetadata +from .types.vizier_service import SuggestTrialsRequest +from .types.vizier_service import SuggestTrialsResponse __all__ = ( "AcceleratorType", "ActiveLearningConfig", + "AddTrialMeasurementRequest", "Annotation", "AnnotationSpec", "Attribution", "AutomaticResources", + "AutoscalingMetricSpec", "BatchDedicatedResources", "BatchMigrateResourcesOperationMetadata", "BatchMigrateResourcesRequest", @@ -213,6 +240,10 @@ "CancelDataLabelingJobRequest", "CancelHyperparameterTuningJobRequest", "CancelTrainingPipelineRequest", + "CheckTrialEarlyStoppingStateMetatdata", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CompleteTrialRequest", "CompletionStats", "ContainerRegistryDestination", "ContainerSpec", @@ -226,7 +257,9 @@ "CreateHyperparameterTuningJobRequest", "CreateSpecialistPoolOperationMetadata", "CreateSpecialistPoolRequest", + "CreateStudyRequest", "CreateTrainingPipelineRequest", + "CreateTrialRequest", "CustomJob", "CustomJobSpec", "DataItem", @@ -243,7 +276,9 @@ "DeleteModelRequest", "DeleteOperationMetadata", "DeleteSpecialistPoolRequest", + 
"DeleteStudyRequest", "DeleteTrainingPipelineRequest", + "DeleteTrialRequest", "DeployModelOperationMetadata", "DeployModelRequest", "DeployModelResponse", @@ -286,7 +321,9 @@ "GetModelEvaluationSliceRequest", "GetModelRequest", "GetSpecialistPoolRequest", + "GetStudyRequest", "GetTrainingPipelineRequest", + "GetTrialRequest", "HyperparameterTuningJob", "ImportDataConfig", "ImportDataOperationMetadata", @@ -318,10 +355,17 @@ "ListModelEvaluationsResponse", "ListModelsRequest", "ListModelsResponse", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", "ListSpecialistPoolsRequest", "ListSpecialistPoolsResponse", + "ListStudiesRequest", + "ListStudiesResponse", "ListTrainingPipelinesRequest", "ListTrainingPipelinesResponse", + "ListTrialsRequest", + "ListTrialsResponse", + "LookupStudyRequest", "MachineSpec", "ManualBatchTuningParameters", "Measurement", @@ -352,7 +396,13 @@ "SearchMigratableResourcesResponse", "SmoothGradConfig", "SpecialistPool", + "SpecialistPoolServiceClient", + "StopTrialRequest", + "Study", "StudySpec", + "SuggestTrialsMetadata", + "SuggestTrialsRequest", + "SuggestTrialsResponse", "TimestampSplit", "TrainingConfig", "TrainingPipeline", @@ -371,5 +421,5 @@ "UserActionReference", "WorkerPoolSpec", "XraiAttribution", - "SpecialistPoolServiceClient", + "VizierServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 2915b4888d..d91df4b644 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -97,8 +97,36 @@ class DatasetServiceAsyncClient: DatasetServiceClient.parse_common_location_path ) - from_service_account_info = DatasetServiceClient.from_service_account_info - from_service_account_file = DatasetServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceAsyncClient: The constructed client. + """ + return DatasetServiceClient.from_service_account_info.__func__(DatasetServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceAsyncClient: The constructed client. + """ + return DatasetServiceClient.from_service_account_file.__func__(DatasetServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property @@ -889,7 +917,6 @@ async def get_annotation_spec( name (:class:`str`): Required. The name of the AnnotationSpec resource. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` This corresponds to the ``name`` field @@ -964,7 +991,6 @@ async def list_annotations( parent (:class:`str`): Required. The resource name of the DataItem to list Annotations from. 
Format: - ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` This corresponds to the ``parent`` field diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 187b9c8f0c..37aecfc5e5 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -926,9 +926,8 @@ def import_data( if name is not None: request.name = name - - if import_configs: - request.import_configs.extend(import_configs) + if import_configs is not None: + request.import_configs = import_configs # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -1152,7 +1151,6 @@ def get_annotation_spec( name (str): Required. The name of the AnnotationSpec resource. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` This corresponds to the ``name`` field @@ -1228,7 +1226,6 @@ def list_annotations( parent (str): Required. The resource name of the DataItem to list Annotations from. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` This corresponds to the ``parent`` field diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index 4c5d248571..63560b32ba 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index b4fd90ee1f..4dae75d109 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -248,8 +248,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 43242b1148..05aa538225 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -87,8 +87,36 @@ class EndpointServiceAsyncClient: EndpointServiceClient.parse_common_location_path ) - from_service_account_info = EndpointServiceClient.from_service_account_info - from_service_account_file = EndpointServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. 
+ kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceAsyncClient: The constructed client. + """ + return EndpointServiceClient.from_service_account_info.__func__(EndpointServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceAsyncClient: The constructed client. + """ + return EndpointServiceClient.from_service_account_file.__func__(EndpointServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 35f968b7c7..1fdf1e506e 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -903,9 +903,8 @@ def deploy_model( request.endpoint = endpoint if deployed_model is not None: request.deployed_model = deployed_model - - if traffic_split: - request.traffic_split.update(traffic_split) + if traffic_split is not None: + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -1022,9 +1021,8 @@ def undeploy_model( request.endpoint = endpoint if deployed_model_id is not None: request.deployed_model_id = deployed_model_id - - if traffic_split: - request.traffic_split.update(traffic_split) + if traffic_split is not None: + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 1ceb718df1..db3172bcef 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index e5b820de61..455ed12cf4 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -247,8 +247,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. 
- """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 4c22267c01..366cbf0f52 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -119,8 +119,36 @@ class JobServiceAsyncClient: JobServiceClient.parse_common_location_path ) - from_service_account_info = JobServiceClient.from_service_account_info - from_service_account_file = JobServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + JobServiceAsyncClient: The constructed client. + """ + return JobServiceClient.from_service_account_info.__func__(JobServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + JobServiceAsyncClient: The constructed client. + """ + return JobServiceClient.from_service_account_file.__func__(JobServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property @@ -709,7 +737,6 @@ async def get_data_labeling_job( [DataLabelingJobService.GetDataLabelingJob][]. name (:class:`str`): Required. The name of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` This corresponds to the ``name`` field @@ -867,7 +894,6 @@ async def delete_data_labeling_job( name (:class:`str`): Required. The name of the DataLabelingJob to be deleted. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` This corresponds to the ``name`` field @@ -963,7 +989,6 @@ async def cancel_data_labeling_job( [DataLabelingJobService.CancelDataLabelingJob][]. name (:class:`str`): Required. The name of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` This corresponds to the ``name`` field @@ -1117,7 +1142,6 @@ async def get_hyperparameter_tuning_job( name (:class:`str`): Required. The name of the HyperparameterTuningJob resource. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` This corresponds to the ``name`` field @@ -1277,7 +1301,6 @@ async def delete_hyperparameter_tuning_job( name (:class:`str`): Required. The name of the HyperparameterTuningJob resource to be deleted. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` This corresponds to the ``name`` field @@ -1386,7 +1409,6 @@ async def cancel_hyperparameter_tuning_job( name (:class:`str`): Required. The name of the HyperparameterTuningJob to cancel. 
Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` This corresponds to the ``name`` field @@ -1543,7 +1565,6 @@ async def get_batch_prediction_job( name (:class:`str`): Required. The name of the BatchPredictionJob resource. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` This corresponds to the ``name`` field @@ -1706,7 +1727,6 @@ async def delete_batch_prediction_job( name (:class:`str`): Required. The name of the BatchPredictionJob resource to be deleted. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` This corresponds to the ``name`` field @@ -1813,7 +1833,6 @@ async def cancel_batch_prediction_job( name (:class:`str`): Required. The name of the BatchPredictionJob to cancel. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index 54a53f26db..81fa0d786f 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -1007,7 +1007,6 @@ def get_data_labeling_job( [DataLabelingJobService.GetDataLabelingJob][]. name (str): Required. The name of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` This corresponds to the ``name`` field @@ -1167,7 +1166,6 @@ def delete_data_labeling_job( name (str): Required. The name of the DataLabelingJob to be deleted. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` This corresponds to the ``name`` field @@ -1264,7 +1262,6 @@ def cancel_data_labeling_job( [DataLabelingJobService.CancelDataLabelingJob][]. name (str): Required. The name of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` This corresponds to the ``name`` field @@ -1422,7 +1419,6 @@ def get_hyperparameter_tuning_job( name (str): Required. The name of the HyperparameterTuningJob resource. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` This corresponds to the ``name`` field @@ -1588,7 +1584,6 @@ def delete_hyperparameter_tuning_job( name (str): Required. The name of the HyperparameterTuningJob resource to be deleted. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` This corresponds to the ``name`` field @@ -1700,7 +1695,6 @@ def cancel_hyperparameter_tuning_job( name (str): Required. The name of the HyperparameterTuningJob to cancel. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` This corresponds to the ``name`` field @@ -1863,7 +1857,6 @@ def get_batch_prediction_job( name (str): Required. The name of the BatchPredictionJob resource. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` This corresponds to the ``name`` field @@ -2030,7 +2023,6 @@ def delete_batch_prediction_job( name (str): Required. The name of the BatchPredictionJob resource to be deleted. 
Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` This corresponds to the ``name`` field @@ -2140,7 +2132,6 @@ def cancel_batch_prediction_job( name (str): Required. The name of the BatchPredictionJob to cancel. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 845939923f..6c3da33d0a 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index c4efeaaf47..763f510e5b 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -262,8 +262,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index 7577e15d1c..c4db3f14d7 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -96,8 +96,36 @@ class MigrationServiceAsyncClient: MigrationServiceClient.parse_common_location_path ) - from_service_account_info = MigrationServiceClient.from_service_account_info - from_service_account_file = MigrationServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceAsyncClient: The constructed client. + """ + return MigrationServiceClient.from_service_account_info.__func__(MigrationServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceAsyncClient: The constructed client. 
+        """
+        return MigrationServiceClient.from_service_account_file.__func__(MigrationServiceAsyncClient, filename, *args, **kwargs)  # type: ignore
+
     from_service_account_json = from_service_account_file

     @property
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
index 6d88a39046..501f21183f 100644
--- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py
@@ -180,32 +180,32 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]:
         return m.groupdict() if m else {}

     @staticmethod
-    def dataset_path(project: str, location: str, dataset: str,) -> str:
+    def dataset_path(project: str, dataset: str,) -> str:
         """Return a fully-qualified dataset string."""
-        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
-            project=project, location=location, dataset=dataset,
+        return "projects/{project}/datasets/{dataset}".format(
+            project=project, dataset=dataset,
         )

     @staticmethod
     def parse_dataset_path(path: str) -> Dict[str, str]:
         """Parse a dataset path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
-            path,
-        )
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def dataset_path(project: str, dataset: str,) -> str:
+    def dataset_path(project: str, location: str, dataset: str,) -> str:
         """Return a fully-qualified dataset string."""
-        return "projects/{project}/datasets/{dataset}".format(
-            project=project, dataset=dataset,
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+            project=project, location=location, dataset=dataset,
         )

     @staticmethod
     def parse_dataset_path(path: str) -> Dict[str, str]:
         """Parse a dataset path into its component segments."""
-        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
+            path,
+        )
         return m.groupdict() if m else {}

     @staticmethod
@@ -612,9 +612,8 @@ def batch_migrate_resources(

         if parent is not None:
             request.parent = parent
-
-        if migrate_resource_requests:
-            request.migrate_resource_requests.extend(migrate_resource_requests)
+        if migrate_resource_requests is not None:
+            request.migrate_resource_requests = migrate_resource_requests

         # Wrap the RPC method; this adds retry and timeout information,
         # and friendly error handling.
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py
index d231e61235..f0a1dfa43f 100644
--- a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py
@@ -15,7 +15,16 @@
 # limitations under the License.
# -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import migratable_resource from google.cloud.aiplatform_v1beta1.types import migration_service diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index 0c5e1a080e..6789c12718 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -249,8 +249,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index c6b2b5f0fc..a901ead2b1 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -102,8 +102,36 @@ class ModelServiceAsyncClient: ModelServiceClient.parse_common_location_path ) - from_service_account_info = ModelServiceClient.from_service_account_info - from_service_account_file = ModelServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_info.__func__(ModelServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_file.__func__(ModelServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property @@ -715,7 +743,6 @@ async def get_model_evaluation( name (:class:`str`): Required. The name of the ModelEvaluation resource. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` This corresponds to the ``name`` field @@ -875,7 +902,6 @@ async def get_model_evaluation_slice( name (:class:`str`): Required. The name of the ModelEvaluationSlice resource. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` This corresponds to the ``name`` field @@ -952,7 +978,6 @@ async def list_model_evaluation_slices( parent (:class:`str`): Required. The resource name of the ModelEvaluation to list the ModelEvaluationSlices from. 
Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` This corresponds to the ``parent`` field diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index d357ad3b9a..8b14e16e0b 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -983,7 +983,6 @@ def get_model_evaluation( name (str): Required. The name of the ModelEvaluation resource. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` This corresponds to the ``name`` field @@ -1145,7 +1144,6 @@ def get_model_evaluation_slice( name (str): Required. The name of the ModelEvaluationSlice resource. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` This corresponds to the ``name`` field @@ -1225,7 +1223,6 @@ def list_model_evaluation_slices( parent (str): Required. The resource name of the ModelEvaluation to list the ModelEvaluationSlices from. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` This corresponds to the ``parent`` field diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 046f462b45..eb547a5f9f 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 39452c0cd6..b401612b1c 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -251,8 +251,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index 170b5b8f59..063153700c 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -96,8 +96,36 @@ class PipelineServiceAsyncClient: PipelineServiceClient.parse_common_location_path ) - from_service_account_info = PipelineServiceClient.from_service_account_info - from_service_account_file = PipelineServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. 
+ + Returns: + PipelineServiceAsyncClient: The constructed client. + """ + return PipelineServiceClient.from_service_account_info.__func__(PipelineServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PipelineServiceAsyncClient: The constructed client. + """ + return PipelineServiceClient.from_service_account_file.__func__(PipelineServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property @@ -268,7 +296,6 @@ async def get_training_pipeline( name (:class:`str`): Required. The name of the TrainingPipeline resource. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` This corresponds to the ``name`` field @@ -430,7 +457,6 @@ async def delete_training_pipeline( name (:class:`str`): Required. The name of the TrainingPipeline resource to be deleted. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` This corresponds to the ``name`` field @@ -538,7 +564,6 @@ async def cancel_training_pipeline( name (:class:`str`): Required. The name of the TrainingPipeline to cancel. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 25aa10df28..4efc2064b5 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -499,7 +499,6 @@ def get_training_pipeline( name (str): Required. The name of the TrainingPipeline resource. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` This corresponds to the ``name`` field @@ -663,7 +662,6 @@ def delete_training_pipeline( name (str): Required. The name of the TrainingPipeline resource to be deleted. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` This corresponds to the ``name`` field @@ -772,7 +770,6 @@ def cancel_training_pipeline( name (str): Required. The name of the TrainingPipeline to cancel. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 1c8616e0a1..db2b4dd3a1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. 
# -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 6aa6880fdb..83383d9e87 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -252,8 +252,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index 8eeae282f6..4d69a6635f 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -77,8 +77,36 @@ class PredictionServiceAsyncClient: PredictionServiceClient.parse_common_location_path ) - from_service_account_info = PredictionServiceClient.from_service_account_info - from_service_account_file = PredictionServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceAsyncClient: The constructed client. + """ + return PredictionServiceClient.from_service_account_info.__func__(PredictionServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceAsyncClient: The constructed client. + """ + return PredictionServiceClient.from_service_account_file.__func__(PredictionServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 89b90b5c12..042307eca1 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -433,12 +433,11 @@ def predict( if endpoint is not None: request.endpoint = endpoint + if instances is not None: + request.instances.extend(instances) if parameters is not None: request.parameters = parameters - if instances: - request.instances.extend(instances) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
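
# --- Editor's illustrative sketch (not part of the patch above) ---------------
# The reordered block above fills the repeated `instances` field via extend()
# before assigning the singular `parameters` field. A hypothetical call to the
# flattened predict surface, converting plain dicts to protobuf Values the way
# this repository's samples do; the endpoint name and payload are placeholders.
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

def to_value(obj: dict) -> Value:
    # Parse a plain dict into a protobuf Value suitable for `instances`.
    return json_format.ParseDict(obj, Value())

# client.predict(
#     endpoint="projects/my-project/locations/us-central1/endpoints/123",
#     instances=[to_value({"content": "hello"})],
#     parameters=to_value({}),
# )
# ------------------------------------------------------------------------------
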
rpc = self._transport._wrapped_methods[self._transport.predict] @@ -562,14 +561,13 @@ def explain( if endpoint is not None: request.endpoint = endpoint + if instances is not None: + request.instances.extend(instances) if parameters is not None: request.parameters = parameters if deployed_model_id is not None: request.deployed_model_id = deployed_model_id - if instances: - request.instances.extend(instances) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.explain] diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 53345710c6..f3b9be0c3d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -244,8 +244,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index c85d91436e..6907135b53 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -95,8 +95,36 @@ class SpecialistPoolServiceAsyncClient: SpecialistPoolServiceClient.parse_common_location_path ) - from_service_account_info = SpecialistPoolServiceClient.from_service_account_info - from_service_account_file = SpecialistPoolServiceClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceAsyncClient: The constructed client. + """ + return SpecialistPoolServiceClient.from_service_account_info.__func__(SpecialistPoolServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceAsyncClient: The constructed client. + """ + return SpecialistPoolServiceClient.from_service_account_file.__func__(SpecialistPoolServiceAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property @@ -279,7 +307,6 @@ async def get_specialist_pool( name (:class:`str`): Required. The name of the SpecialistPool resource. The form is - ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. 
This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index 6018955006..cde21b3720 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -477,7 +477,6 @@ def get_specialist_pool( name (str): Required. The name of the SpecialistPool resource. The form is - ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 61a5f5de57..976bcf55b8 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -15,7 +15,16 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 61c82508b9..dbc31f0c7e 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -253,8 +253,7 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py new file mode 100644 index 0000000000..5c312868f1 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .client import VizierServiceClient +from .async_client import VizierServiceAsyncClient + +__all__ = ( + "VizierServiceClient", + "VizierServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py new file mode 100644 index 0000000000..4bd90a79cd --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/async_client.py @@ -0,0 +1,1261 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.vizier_service import pagers +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.cloud.aiplatform_v1beta1.types import vizier_service +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import VizierServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import VizierServiceGrpcAsyncIOTransport +from .client import VizierServiceClient + + +class VizierServiceAsyncClient: + """Cloud AI Platform Vizier API. + Vizier service is a GCP service to solve blackbox optimization + problems, such as tuning machine learning hyperparameters and + searching over deep learning architectures. 
+ """ + + _client: VizierServiceClient + + DEFAULT_ENDPOINT = VizierServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = VizierServiceClient.DEFAULT_MTLS_ENDPOINT + + custom_job_path = staticmethod(VizierServiceClient.custom_job_path) + parse_custom_job_path = staticmethod(VizierServiceClient.parse_custom_job_path) + study_path = staticmethod(VizierServiceClient.study_path) + parse_study_path = staticmethod(VizierServiceClient.parse_study_path) + trial_path = staticmethod(VizierServiceClient.trial_path) + parse_trial_path = staticmethod(VizierServiceClient.parse_trial_path) + + common_billing_account_path = staticmethod( + VizierServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + VizierServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(VizierServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + VizierServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + VizierServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + VizierServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(VizierServiceClient.common_project_path) + parse_common_project_path = staticmethod( + VizierServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(VizierServiceClient.common_location_path) + parse_common_location_path = staticmethod( + VizierServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + VizierServiceAsyncClient: The constructed client. + """ + return VizierServiceClient.from_service_account_info.__func__(VizierServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + VizierServiceAsyncClient: The constructed client. + """ + return VizierServiceClient.from_service_account_file.__func__(VizierServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> VizierServiceTransport: + """Return the transport used by the client instance. + + Returns: + VizierServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(VizierServiceClient).get_transport_class, type(VizierServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, VizierServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the vizier service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.VizierServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = VizierServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_study( + self, + request: vizier_service.CreateStudyRequest = None, + *, + parent: str = None, + study: gca_study.Study = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_study.Study: + r"""Creates a Study. A resource name will be generated + after creation of the Study. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateStudyRequest`): + The request object. Request message for + ``VizierService.CreateStudy``. + parent (:class:`str`): + Required. The resource name of the Location to create + the CustomJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + study (:class:`google.cloud.aiplatform_v1beta1.types.Study`): + Required. The Study configuration + used to create the Study. + + This corresponds to the ``study`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Study: + A message representing a Study. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, study]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.CreateStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
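
# --- Editor's illustrative sketch (not part of the patch above) ---------------
# The client_options behaviour described in the constructor docstring above can
# be exercised by overriding the API endpoint explicitly. This assumes ambient
# Application Default Credentials; the regional endpoint is a placeholder that
# follows the documented "{region}-aiplatform.googleapis.com" pattern.
from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceAsyncClient

client = VizierServiceAsyncClient(
    client_options=ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com"),
)
# ------------------------------------------------------------------------------
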
+ + if parent is not None: + request.parent = parent + if study is not None: + request.study = study + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_study, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_study( + self, + request: vizier_service.GetStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: + r"""Gets a Study by name. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetStudyRequest`): + The request object. Request message for + ``VizierService.GetStudy``. + name (:class:`str`): + Required. The name of the Study resource. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Study: + A message representing a Study. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.GetStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_study, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_studies( + self, + request: vizier_service.ListStudiesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListStudiesAsyncPager: + r"""Lists all the studies in a region for an associated + project. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListStudiesRequest`): + The request object. Request message for + ``VizierService.ListStudies``. + parent (:class:`str`): + Required. The resource name of the Location to list the + Study from. 
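
# --- Editor's illustrative sketch (not part of the patch above) ---------------
# A hypothetical end-to-end use of the new async surface shown above:
# create_study() followed by get_study(). It assumes Application Default
# Credentials are available; the project/location IDs and display_name are
# placeholders, and a real Study would also carry a study_spec, omitted here
# for brevity.
import asyncio
from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceAsyncClient
from google.cloud.aiplatform_v1beta1.types import study as gca_study

async def demo() -> None:
    client = VizierServiceAsyncClient()
    created = await client.create_study(
        parent="projects/my-project/locations/us-central1",
        study=gca_study.Study(display_name="my-study"),
    )
    fetched = await client.get_study(name=created.name)
    print(fetched.display_name)

# asyncio.run(demo())
# ------------------------------------------------------------------------------
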
Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.vizier_service.pagers.ListStudiesAsyncPager: + Response message for + ``VizierService.ListStudies``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.ListStudiesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_studies, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListStudiesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_study( + self, + request: vizier_service.DeleteStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a Study. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteStudyRequest`): + The request object. Request message for + ``VizierService.DeleteStudy``. + name (:class:`str`): + Required. The name of the Study resource to be deleted. + Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = vizier_service.DeleteStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_study, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def lookup_study( + self, + request: vizier_service.LookupStudyRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: + r"""Looks a study up using the user-defined display_name field + instead of the fully qualified resource name. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.LookupStudyRequest`): + The request object. Request message for + ``VizierService.LookupStudy``. + parent (:class:`str`): + Required. The resource name of the Location to get the + Study from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Study: + A message representing a Study. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.LookupStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.lookup_study, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def suggest_trials( + self, + request: vizier_service.SuggestTrialsRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Adds one or more Trials to a Study, with parameter values + suggested by AI Platform Vizier. Returns a long-running + operation associated with the generation of Trial suggestions. 
+ When this long-running operation succeeds, it will contain a + ``SuggestTrialsResponse``. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.SuggestTrialsRequest`): + The request object. Request message for + ``VizierService.SuggestTrials``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.SuggestTrialsResponse` + Response message for + ``VizierService.SuggestTrials``. + + """ + # Create or coerce a protobuf request object. + + request = vizier_service.SuggestTrialsRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.suggest_trials, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + vizier_service.SuggestTrialsResponse, + metadata_type=vizier_service.SuggestTrialsMetadata, + ) + + # Done; return the response. + return response + + async def create_trial( + self, + request: vizier_service.CreateTrialRequest = None, + *, + parent: str = None, + trial: study.Trial = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Adds a user provided Trial to a Study. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateTrialRequest`): + The request object. Request message for + ``VizierService.CreateTrial``. + parent (:class:`str`): + Required. The resource name of the Study to create the + Trial in. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + trial (:class:`google.cloud.aiplatform_v1beta1.types.Trial`): + Required. The Trial to create. + This corresponds to the ``trial`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
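
# --- Editor's illustrative sketch (not part of the patch above) ---------------
# Consuming the long-running suggest_trials() operation defined above. The
# request fields used here (parent, suggestion_count, client_id) come from the
# v1beta1 SuggestTrialsRequest; the worker id is a placeholder and the study
# name is supplied by the caller.
from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceAsyncClient
from google.cloud.aiplatform_v1beta1.types import vizier_service

async def suggest(client: VizierServiceAsyncClient, study_name: str):
    # Kick off the LRO and wait until AI Platform Vizier returns suggestions.
    operation = await client.suggest_trials(
        vizier_service.SuggestTrialsRequest(
            parent=study_name,
            suggestion_count=1,
            client_id="worker-0",
        )
    )
    response = await operation.result()  # vizier_service.SuggestTrialsResponse
    return list(response.trials)
# ------------------------------------------------------------------------------
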
+ has_flattened_params = any([parent, trial]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.CreateTrialRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if trial is not None: + request.trial = trial + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_trial, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_trial( + self, + request: vizier_service.GetTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Gets a Trial. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.GetTrialRequest`): + The request object. Request message for + ``VizierService.GetTrial``. + name (:class:`str`): + Required. The name of the Trial resource. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.GetTrialRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_trial, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + async def list_trials( + self, + request: vizier_service.ListTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrialsAsyncPager: + r"""Lists the Trials associated with a Study. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListTrialsRequest`): + The request object. Request message for + ``VizierService.ListTrials``. + parent (:class:`str`): + Required. The resource name of the Study to list the + Trial from. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.vizier_service.pagers.ListTrialsAsyncPager: + Response message for + ``VizierService.ListTrials``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.ListTrialsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_trials, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTrialsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def add_trial_measurement( + self, + request: vizier_service.AddTrialMeasurementRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Adds a measurement of the objective metrics to a + Trial. This measurement is assumed to have been taken + before the Trial is complete. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.AddTrialMeasurementRequest`): + The request object. Request message for + ``VizierService.AddTrialMeasurement``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
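
# --- Editor's illustrative sketch (not part of the patch above) ---------------
# Iterating the ListTrialsAsyncPager returned by list_trials() above; extra
# pages are fetched transparently as the async iterator is consumed. The study
# name is supplied by the caller and is a placeholder-style resource name.
async def print_trials(client, study_name: str) -> None:
    pager = await client.list_trials(parent=study_name)
    async for trial in pager:
        print(trial.name, trial.state)
# ------------------------------------------------------------------------------
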
+ + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + + request = vizier_service.AddTrialMeasurementRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.add_trial_measurement, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def complete_trial( + self, + request: vizier_service.CompleteTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Marks a Trial as complete. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CompleteTrialRequest`): + The request object. Request message for + ``VizierService.CompleteTrial``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + + request = vizier_service.CompleteTrialRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.complete_trial, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def delete_trial( + self, + request: vizier_service.DeleteTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a Trial. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteTrialRequest`): + The request object. Request message for + ``VizierService.DeleteTrial``. + name (:class:`str`): + Required. The Trial's name. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.DeleteTrialRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_trial, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def check_trial_early_stopping_state( + self, + request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Checks whether a Trial should stop or not. Returns a + long-running operation. When the operation is successful, it + will contain a + ``CheckTrialEarlyStoppingStateResponse``. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.CheckTrialEarlyStoppingStateRequest`): + The request object. Request message for + ``VizierService.CheckTrialEarlyStoppingState``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.CheckTrialEarlyStoppingStateResponse` + Response message for + ``VizierService.CheckTrialEarlyStoppingState``. + + """ + # Create or coerce a protobuf request object. + + request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.check_trial_early_stopping_state, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + vizier_service.CheckTrialEarlyStoppingStateResponse, + metadata_type=vizier_service.CheckTrialEarlyStoppingStateMetatdata, + ) + + # Done; return the response. 
+ return response + + async def stop_trial( + self, + request: vizier_service.StopTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Stops a Trial. + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.StopTrialRequest`): + The request object. Request message for + ``VizierService.StopTrial``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + + request = vizier_service.StopTrialRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.stop_trial, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_optimal_trials( + self, + request: vizier_service.ListOptimalTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> vizier_service.ListOptimalTrialsResponse: + r"""Lists the pareto-optimal Trials for multi-objective Study or the + optimal Trials for single-objective Study. The definition of + pareto-optimal can be checked in wiki page. + https://en.wikipedia.org/wiki/Pareto_efficiency + + Args: + request (:class:`google.cloud.aiplatform_v1beta1.types.ListOptimalTrialsRequest`): + The request object. Request message for + ``VizierService.ListOptimalTrials``. + parent (:class:`str`): + Required. The name of the Study that + the optimal Trial belongs to. + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ListOptimalTrialsResponse: + Response message for + ``VizierService.ListOptimalTrials``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = vizier_service.ListOptimalTrialsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
+ + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_optimal_trials, + default_timeout=5.0, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("VizierServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py new file mode 100644 index 0000000000..85e381323d --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/client.py @@ -0,0 +1,1478 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.vizier_service import pagers +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.cloud.aiplatform_v1beta1.types import vizier_service +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import VizierServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import VizierServiceGrpcTransport +from .transports.grpc_asyncio import VizierServiceGrpcAsyncIOTransport + + +class VizierServiceClientMeta(type): + """Metaclass for the VizierService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+ """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[VizierServiceTransport]] + _transport_registry["grpc"] = VizierServiceGrpcTransport + _transport_registry["grpc_asyncio"] = VizierServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[VizierServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class VizierServiceClient(metaclass=VizierServiceClientMeta): + """Cloud AI Platform Vizier API. + Vizier service is a GCP service to solve blackbox optimization + problems, such as tuning machine learning hyperparameters and + searching over deep learning architectures. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + VizierServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + VizierServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> VizierServiceTransport: + """Return the transport used by the client instance. + + Returns: + VizierServiceTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def custom_job_path(project: str, location: str, custom_job: str,) -> str: + """Return a fully-qualified custom_job string.""" + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + + @staticmethod + def parse_custom_job_path(path: str) -> Dict[str, str]: + """Parse a custom_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def study_path(project: str, location: str, study: str,) -> str: + """Return a fully-qualified study string.""" + return "projects/{project}/locations/{location}/studies/{study}".format( + project=project, location=location, study=study, + ) + + @staticmethod + def parse_study_path(path: str) -> Dict[str, str]: + """Parse a study path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def trial_path(project: str, location: str, study: str, trial: str,) -> str: + """Return a fully-qualified trial string.""" + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) + + @staticmethod + def parse_trial_path(path: str) -> Dict[str, str]: + """Parse a trial path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + 
+ project=project, location=location,
+ )
+
+ @staticmethod
+ def parse_common_location_path(path: str) -> Dict[str, str]:
+ """Parse a location path into its component segments."""
+ m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+ return m.groupdict() if m else {}
+
+ def __init__(
+ self,
+ *,
+ credentials: Optional[credentials.Credentials] = None,
+ transport: Union[str, VizierServiceTransport, None] = None,
+ client_options: Optional[client_options_lib.ClientOptions] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiate the vizier service client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, VizierServiceTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. It won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ if isinstance(client_options, dict):
+ client_options = client_options_lib.from_dict(client_options)
+ if client_options is None:
+ client_options = client_options_lib.ClientOptions()
+
+ # Create SSL credentials for mutual TLS if needed.
+ use_client_cert = bool(
+ util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+ )
+
+ client_cert_source_func = None
+ is_mtls = False
+ if use_client_cert:
+ if client_options.client_cert_source:
+ is_mtls = True
+ client_cert_source_func = client_options.client_cert_source
+ else:
+ is_mtls = mtls.has_default_client_cert_source()
+ client_cert_source_func = (
+ mtls.default_client_cert_source() if is_mtls else None
+ )
+
+ # Figure out which api endpoint to use.
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, VizierServiceTransport): + # transport is a VizierServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_study( + self, + request: vizier_service.CreateStudyRequest = None, + *, + parent: str = None, + study: gca_study.Study = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_study.Study: + r"""Creates a Study. A resource name will be generated + after creation of the Study. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateStudyRequest): + The request object. Request message for + ``VizierService.CreateStudy``. + parent (str): + Required. The resource name of the Location to create + the CustomJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + study (google.cloud.aiplatform_v1beta1.types.Study): + Required. The Study configuration + used to create the Study. + + This corresponds to the ``study`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Study: + A message representing a Study. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, study]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.CreateStudyRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, vizier_service.CreateStudyRequest): + request = vizier_service.CreateStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if study is not None: + request.study = study + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_study] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_study( + self, + request: vizier_service.GetStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: + r"""Gets a Study by name. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetStudyRequest): + The request object. Request message for + ``VizierService.GetStudy``. + name (str): + Required. The name of the Study resource. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Study: + A message representing a Study. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.GetStudyRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.GetStudyRequest): + request = vizier_service.GetStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_study] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + def list_studies( + self, + request: vizier_service.ListStudiesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListStudiesPager: + r"""Lists all the studies in a region for an associated + project. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListStudiesRequest): + The request object. Request message for + ``VizierService.ListStudies``. + parent (str): + Required. The resource name of the Location to list the + Study from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.vizier_service.pagers.ListStudiesPager: + Response message for + ``VizierService.ListStudies``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.ListStudiesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.ListStudiesRequest): + request = vizier_service.ListStudiesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_studies] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListStudiesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_study( + self, + request: vizier_service.DeleteStudyRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a Study. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteStudyRequest): + The request object. Request message for + ``VizierService.DeleteStudy``. + name (str): + Required. The name of the Study resource to be deleted. 
+ Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.DeleteStudyRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.DeleteStudyRequest): + request = vizier_service.DeleteStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_study] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def lookup_study( + self, + request: vizier_service.LookupStudyRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Study: + r"""Looks a study up using the user-defined display_name field + instead of the fully qualified resource name. + + Args: + request (google.cloud.aiplatform_v1beta1.types.LookupStudyRequest): + The request object. Request message for + ``VizierService.LookupStudy``. + parent (str): + Required. The resource name of the Location to get the + Study from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Study: + A message representing a Study. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.LookupStudyRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, vizier_service.LookupStudyRequest): + request = vizier_service.LookupStudyRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.lookup_study] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def suggest_trials( + self, + request: vizier_service.SuggestTrialsRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Adds one or more Trials to a Study, with parameter values + suggested by AI Platform Vizier. Returns a long-running + operation associated with the generation of Trial suggestions. + When this long-running operation succeeds, it will contain a + ``SuggestTrialsResponse``. + + Args: + request (google.cloud.aiplatform_v1beta1.types.SuggestTrialsRequest): + The request object. Request message for + ``VizierService.SuggestTrials``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.SuggestTrialsResponse` + Response message for + ``VizierService.SuggestTrials``. + + """ + # Create or coerce a protobuf request object. + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.SuggestTrialsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.SuggestTrialsRequest): + request = vizier_service.SuggestTrialsRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.suggest_trials] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + vizier_service.SuggestTrialsResponse, + metadata_type=vizier_service.SuggestTrialsMetadata, + ) + + # Done; return the response. + return response + + def create_trial( + self, + request: vizier_service.CreateTrialRequest = None, + *, + parent: str = None, + trial: study.Trial = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Adds a user provided Trial to a Study. 
+ + Args: + request (google.cloud.aiplatform_v1beta1.types.CreateTrialRequest): + The request object. Request message for + ``VizierService.CreateTrial``. + parent (str): + Required. The resource name of the Study to create the + Trial in. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + trial (google.cloud.aiplatform_v1beta1.types.Trial): + Required. The Trial to create. + This corresponds to the ``trial`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, trial]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.CreateTrialRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.CreateTrialRequest): + request = vizier_service.CreateTrialRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if trial is not None: + request.trial = trial + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_trial] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_trial( + self, + request: vizier_service.GetTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Gets a Trial. + + Args: + request (google.cloud.aiplatform_v1beta1.types.GetTrialRequest): + The request object. Request message for + ``VizierService.GetTrial``. + name (str): + Required. The name of the Trial resource. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.GetTrialRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.GetTrialRequest): + request = vizier_service.GetTrialRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_trial] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_trials( + self, + request: vizier_service.ListTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrialsPager: + r"""Lists the Trials associated with a Study. + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListTrialsRequest): + The request object. Request message for + ``VizierService.ListTrials``. + parent (str): + Required. The resource name of the Study to list the + Trial from. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.services.vizier_service.pagers.ListTrialsPager: + Response message for + ``VizierService.ListTrials``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.ListTrialsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, vizier_service.ListTrialsRequest): + request = vizier_service.ListTrialsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_trials] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTrialsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def add_trial_measurement( + self, + request: vizier_service.AddTrialMeasurementRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Adds a measurement of the objective metrics to a + Trial. This measurement is assumed to have been taken + before the Trial is complete. + + Args: + request (google.cloud.aiplatform_v1beta1.types.AddTrialMeasurementRequest): + The request object. Request message for + ``VizierService.AddTrialMeasurement``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.AddTrialMeasurementRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.AddTrialMeasurementRequest): + request = vizier_service.AddTrialMeasurementRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.add_trial_measurement] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def complete_trial( + self, + request: vizier_service.CompleteTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Marks a Trial as complete. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CompleteTrialRequest): + The request object. Request message for + ``VizierService.CompleteTrial``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.CompleteTrialRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.CompleteTrialRequest): + request = vizier_service.CompleteTrialRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.complete_trial] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def delete_trial( + self, + request: vizier_service.DeleteTrialRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a Trial. + + Args: + request (google.cloud.aiplatform_v1beta1.types.DeleteTrialRequest): + The request object. Request message for + ``VizierService.DeleteTrial``. + name (str): + Required. The Trial's name. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.DeleteTrialRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.DeleteTrialRequest): + request = vizier_service.DeleteTrialRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_trial] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def check_trial_early_stopping_state( + self, + request: vizier_service.CheckTrialEarlyStoppingStateRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Checks whether a Trial should stop or not. Returns a + long-running operation. When the operation is successful, it + will contain a + ``CheckTrialEarlyStoppingStateResponse``. + + Args: + request (google.cloud.aiplatform_v1beta1.types.CheckTrialEarlyStoppingStateRequest): + The request object. Request message for + ``VizierService.CheckTrialEarlyStoppingState``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1beta1.types.CheckTrialEarlyStoppingStateResponse` + Response message for + ``VizierService.CheckTrialEarlyStoppingState``. + + """ + # Create or coerce a protobuf request object. + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.CheckTrialEarlyStoppingStateRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.CheckTrialEarlyStoppingStateRequest): + request = vizier_service.CheckTrialEarlyStoppingStateRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.check_trial_early_stopping_state + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("trial_name", request.trial_name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + vizier_service.CheckTrialEarlyStoppingStateResponse, + metadata_type=vizier_service.CheckTrialEarlyStoppingStateMetatdata, + ) + + # Done; return the response. + return response + + def stop_trial( + self, + request: vizier_service.StopTrialRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> study.Trial: + r"""Stops a Trial. + + Args: + request (google.cloud.aiplatform_v1beta1.types.StopTrialRequest): + The request object. Request message for + ``VizierService.StopTrial``. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.Trial: + A message representing a Trial. A + Trial contains a unique set of + Parameters that has been or will be + evaluated, along with the objective + metrics got by running the Trial. + + """ + # Create or coerce a protobuf request object. 
+ + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.StopTrialRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.StopTrialRequest): + request = vizier_service.StopTrialRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.stop_trial] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_optimal_trials( + self, + request: vizier_service.ListOptimalTrialsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> vizier_service.ListOptimalTrialsResponse: + r"""Lists the pareto-optimal Trials for multi-objective Study or the + optimal Trials for single-objective Study. The definition of + pareto-optimal can be checked in wiki page. + https://en.wikipedia.org/wiki/Pareto_efficiency + + Args: + request (google.cloud.aiplatform_v1beta1.types.ListOptimalTrialsRequest): + The request object. Request message for + ``VizierService.ListOptimalTrials``. + parent (str): + Required. The name of the Study that + the optimal Trial belongs to. + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1beta1.types.ListOptimalTrialsResponse: + Response message for + ``VizierService.ListOptimalTrials``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a vizier_service.ListOptimalTrialsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, vizier_service.ListOptimalTrialsRequest): + request = vizier_service.ListOptimalTrialsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_optimal_trials] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("VizierServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py new file mode 100644 index 0000000000..c6e4fcdf63 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/pagers.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import vizier_service + + +class ListStudiesPager: + """A pager for iterating through ``list_studies`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListStudiesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``studies`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListStudies`` requests and continue to iterate + through the ``studies`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListStudiesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., vizier_service.ListStudiesResponse], + request: vizier_service.ListStudiesRequest, + response: vizier_service.ListStudiesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListStudiesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListStudiesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = vizier_service.ListStudiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[vizier_service.ListStudiesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[study.Study]: + for page in self.pages: + yield from page.studies + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListStudiesAsyncPager: + """A pager for iterating through ``list_studies`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListStudiesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``studies`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListStudies`` requests and continue to iterate + through the ``studies`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListStudiesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[vizier_service.ListStudiesResponse]], + request: vizier_service.ListStudiesRequest, + response: vizier_service.ListStudiesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListStudiesRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListStudiesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = vizier_service.ListStudiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[vizier_service.ListStudiesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[study.Study]: + async def async_generator(): + async for page in self.pages: + for response in page.studies: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTrialsPager: + """A pager for iterating through ``list_trials`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTrialsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``trials`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTrials`` requests and continue to iterate + through the ``trials`` field on the + corresponding responses. 
+ + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTrialsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., vizier_service.ListTrialsResponse], + request: vizier_service.ListTrialsRequest, + response: vizier_service.ListTrialsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTrialsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTrialsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = vizier_service.ListTrialsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[vizier_service.ListTrialsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[study.Trial]: + for page in self.pages: + yield from page.trials + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTrialsAsyncPager: + """A pager for iterating through ``list_trials`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1beta1.types.ListTrialsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``trials`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTrials`` requests and continue to iterate + through the ``trials`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTrialsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[vizier_service.ListTrialsResponse]], + request: vizier_service.ListTrialsRequest, + response: vizier_service.ListTrialsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1beta1.types.ListTrialsRequest): + The initial request object. + response (google.cloud.aiplatform_v1beta1.types.ListTrialsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = vizier_service.ListTrialsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[vizier_service.ListTrialsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[study.Trial]: + async def async_generator(): + async for page in self.pages: + for response in page.trials: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py new file mode 100644 index 0000000000..3ed347a603 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import VizierServiceTransport +from .grpc import VizierServiceGrpcTransport +from .grpc_asyncio import VizierServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[VizierServiceTransport]] +_transport_registry["grpc"] = VizierServiceGrpcTransport +_transport_registry["grpc_asyncio"] = VizierServiceGrpcAsyncIOTransport + +__all__ = ( + "VizierServiceTransport", + "VizierServiceGrpcTransport", + "VizierServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py new file mode 100644 index 0000000000..2fdfb4b13f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/base.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.cloud.aiplatform_v1beta1.types import vizier_service +from google.longrunning import operations_pb2 as operations # type: ignore +from google.protobuf import empty_pb2 as empty # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class VizierServiceTransport(abc.ABC): + """Abstract transport class for VizierService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) + + # Save the credentials. + self._credentials = credentials + + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. 
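+ # Each RPC callable below is wrapped with ``gapic_v1.method.wrap_method``
+ # so that the default timeout and the ``client_info`` user-agent metadata
+ # are applied consistently when the client invokes it.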
+ self._wrapped_methods = { + self.create_study: gapic_v1.method.wrap_method( + self.create_study, default_timeout=5.0, client_info=client_info, + ), + self.get_study: gapic_v1.method.wrap_method( + self.get_study, default_timeout=5.0, client_info=client_info, + ), + self.list_studies: gapic_v1.method.wrap_method( + self.list_studies, default_timeout=5.0, client_info=client_info, + ), + self.delete_study: gapic_v1.method.wrap_method( + self.delete_study, default_timeout=5.0, client_info=client_info, + ), + self.lookup_study: gapic_v1.method.wrap_method( + self.lookup_study, default_timeout=5.0, client_info=client_info, + ), + self.suggest_trials: gapic_v1.method.wrap_method( + self.suggest_trials, default_timeout=5.0, client_info=client_info, + ), + self.create_trial: gapic_v1.method.wrap_method( + self.create_trial, default_timeout=5.0, client_info=client_info, + ), + self.get_trial: gapic_v1.method.wrap_method( + self.get_trial, default_timeout=5.0, client_info=client_info, + ), + self.list_trials: gapic_v1.method.wrap_method( + self.list_trials, default_timeout=5.0, client_info=client_info, + ), + self.add_trial_measurement: gapic_v1.method.wrap_method( + self.add_trial_measurement, + default_timeout=5.0, + client_info=client_info, + ), + self.complete_trial: gapic_v1.method.wrap_method( + self.complete_trial, default_timeout=5.0, client_info=client_info, + ), + self.delete_trial: gapic_v1.method.wrap_method( + self.delete_trial, default_timeout=5.0, client_info=client_info, + ), + self.check_trial_early_stopping_state: gapic_v1.method.wrap_method( + self.check_trial_early_stopping_state, + default_timeout=5.0, + client_info=client_info, + ), + self.stop_trial: gapic_v1.method.wrap_method( + self.stop_trial, default_timeout=5.0, client_info=client_info, + ), + self.list_optimal_trials: gapic_v1.method.wrap_method( + self.list_optimal_trials, default_timeout=5.0, client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_study( + self, + ) -> typing.Callable[ + [vizier_service.CreateStudyRequest], + typing.Union[gca_study.Study, typing.Awaitable[gca_study.Study]], + ]: + raise NotImplementedError() + + @property + def get_study( + self, + ) -> typing.Callable[ + [vizier_service.GetStudyRequest], + typing.Union[study.Study, typing.Awaitable[study.Study]], + ]: + raise NotImplementedError() + + @property + def list_studies( + self, + ) -> typing.Callable[ + [vizier_service.ListStudiesRequest], + typing.Union[ + vizier_service.ListStudiesResponse, + typing.Awaitable[vizier_service.ListStudiesResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_study( + self, + ) -> typing.Callable[ + [vizier_service.DeleteStudyRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + @property + def lookup_study( + self, + ) -> typing.Callable[ + [vizier_service.LookupStudyRequest], + typing.Union[study.Study, typing.Awaitable[study.Study]], + ]: + raise NotImplementedError() + + @property + def suggest_trials( + self, + ) -> typing.Callable[ + [vizier_service.SuggestTrialsRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def create_trial( + self, + ) -> typing.Callable[ + [vizier_service.CreateTrialRequest], + typing.Union[study.Trial, 
typing.Awaitable[study.Trial]], + ]: + raise NotImplementedError() + + @property + def get_trial( + self, + ) -> typing.Callable[ + [vizier_service.GetTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: + raise NotImplementedError() + + @property + def list_trials( + self, + ) -> typing.Callable[ + [vizier_service.ListTrialsRequest], + typing.Union[ + vizier_service.ListTrialsResponse, + typing.Awaitable[vizier_service.ListTrialsResponse], + ], + ]: + raise NotImplementedError() + + @property + def add_trial_measurement( + self, + ) -> typing.Callable[ + [vizier_service.AddTrialMeasurementRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: + raise NotImplementedError() + + @property + def complete_trial( + self, + ) -> typing.Callable[ + [vizier_service.CompleteTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: + raise NotImplementedError() + + @property + def delete_trial( + self, + ) -> typing.Callable[ + [vizier_service.DeleteTrialRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + @property + def check_trial_early_stopping_state( + self, + ) -> typing.Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def stop_trial( + self, + ) -> typing.Callable[ + [vizier_service.StopTrialRequest], + typing.Union[study.Trial, typing.Awaitable[study.Trial]], + ]: + raise NotImplementedError() + + @property + def list_optimal_trials( + self, + ) -> typing.Callable[ + [vizier_service.ListOptimalTrialsRequest], + typing.Union[ + vizier_service.ListOptimalTrialsResponse, + typing.Awaitable[vizier_service.ListOptimalTrialsResponse], + ], + ]: + raise NotImplementedError() + + +__all__ = ("VizierServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py new file mode 100644 index 0000000000..388d2746f5 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc.py @@ -0,0 +1,685 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.cloud.aiplatform_v1beta1.types import vizier_service +from google.longrunning import operations_pb2 as operations # type: ignore +from google.protobuf import empty_pb2 as empty # type: ignore + +from .base import VizierServiceTransport, DEFAULT_CLIENT_INFO + + +class VizierServiceGrpcTransport(VizierServiceTransport): + """gRPC backend transport for VizierService. + + Cloud AI Platform Vizier API. + Vizier service is a GCP service to solve blackbox optimization + problems, such as tuning machine learning hyperparameters and + searching over deep learning architectures. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. 
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. 
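+ # (The base constructor records the host and credentials chosen above
+ # and precomputes the wrapped RPC methods.)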
+ super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + + # Return the client from cache. + return self._operations_client + + @property + def create_study( + self, + ) -> Callable[[vizier_service.CreateStudyRequest], gca_study.Study]: + r"""Return a callable for the create study method over gRPC. + + Creates a Study. A resource name will be generated + after creation of the Study. + + Returns: + Callable[[~.CreateStudyRequest], + ~.Study]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_study" not in self._stubs: + self._stubs["create_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy", + request_serializer=vizier_service.CreateStudyRequest.serialize, + response_deserializer=gca_study.Study.deserialize, + ) + return self._stubs["create_study"] + + @property + def get_study(self) -> Callable[[vizier_service.GetStudyRequest], study.Study]: + r"""Return a callable for the get study method over gRPC. 
+ + Gets a Study by name. + + Returns: + Callable[[~.GetStudyRequest], + ~.Study]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_study" not in self._stubs: + self._stubs["get_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetStudy", + request_serializer=vizier_service.GetStudyRequest.serialize, + response_deserializer=study.Study.deserialize, + ) + return self._stubs["get_study"] + + @property + def list_studies( + self, + ) -> Callable[ + [vizier_service.ListStudiesRequest], vizier_service.ListStudiesResponse + ]: + r"""Return a callable for the list studies method over gRPC. + + Lists all the studies in a region for an associated + project. + + Returns: + Callable[[~.ListStudiesRequest], + ~.ListStudiesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_studies" not in self._stubs: + self._stubs["list_studies"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListStudies", + request_serializer=vizier_service.ListStudiesRequest.serialize, + response_deserializer=vizier_service.ListStudiesResponse.deserialize, + ) + return self._stubs["list_studies"] + + @property + def delete_study( + self, + ) -> Callable[[vizier_service.DeleteStudyRequest], empty.Empty]: + r"""Return a callable for the delete study method over gRPC. + + Deletes a Study. + + Returns: + Callable[[~.DeleteStudyRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_study" not in self._stubs: + self._stubs["delete_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy", + request_serializer=vizier_service.DeleteStudyRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["delete_study"] + + @property + def lookup_study( + self, + ) -> Callable[[vizier_service.LookupStudyRequest], study.Study]: + r"""Return a callable for the lookup study method over gRPC. + + Looks a study up using the user-defined display_name field + instead of the fully qualified resource name. + + Returns: + Callable[[~.LookupStudyRequest], + ~.Study]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
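+ # The stub is created lazily on first access and cached in
+ # ``self._stubs``, so repeated reads of this property reuse the
+ # same callable.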
+ if "lookup_study" not in self._stubs: + self._stubs["lookup_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy", + request_serializer=vizier_service.LookupStudyRequest.serialize, + response_deserializer=study.Study.deserialize, + ) + return self._stubs["lookup_study"] + + @property + def suggest_trials( + self, + ) -> Callable[[vizier_service.SuggestTrialsRequest], operations.Operation]: + r"""Return a callable for the suggest trials method over gRPC. + + Adds one or more Trials to a Study, with parameter values + suggested by AI Platform Vizier. Returns a long-running + operation associated with the generation of Trial suggestions. + When this long-running operation succeeds, it will contain a + ``SuggestTrialsResponse``. + + Returns: + Callable[[~.SuggestTrialsRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "suggest_trials" not in self._stubs: + self._stubs["suggest_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials", + request_serializer=vizier_service.SuggestTrialsRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["suggest_trials"] + + @property + def create_trial( + self, + ) -> Callable[[vizier_service.CreateTrialRequest], study.Trial]: + r"""Return a callable for the create trial method over gRPC. + + Adds a user provided Trial to a Study. + + Returns: + Callable[[~.CreateTrialRequest], + ~.Trial]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_trial" not in self._stubs: + self._stubs["create_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial", + request_serializer=vizier_service.CreateTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["create_trial"] + + @property + def get_trial(self) -> Callable[[vizier_service.GetTrialRequest], study.Trial]: + r"""Return a callable for the get trial method over gRPC. + + Gets a Trial. + + Returns: + Callable[[~.GetTrialRequest], + ~.Trial]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_trial" not in self._stubs: + self._stubs["get_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetTrial", + request_serializer=vizier_service.GetTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["get_trial"] + + @property + def list_trials( + self, + ) -> Callable[ + [vizier_service.ListTrialsRequest], vizier_service.ListTrialsResponse + ]: + r"""Return a callable for the list trials method over gRPC. + + Lists the Trials associated with a Study. + + Returns: + Callable[[~.ListTrialsRequest], + ~.ListTrialsResponse]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_trials" not in self._stubs: + self._stubs["list_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListTrials", + request_serializer=vizier_service.ListTrialsRequest.serialize, + response_deserializer=vizier_service.ListTrialsResponse.deserialize, + ) + return self._stubs["list_trials"] + + @property + def add_trial_measurement( + self, + ) -> Callable[[vizier_service.AddTrialMeasurementRequest], study.Trial]: + r"""Return a callable for the add trial measurement method over gRPC. + + Adds a measurement of the objective metrics to a + Trial. This measurement is assumed to have been taken + before the Trial is complete. + + Returns: + Callable[[~.AddTrialMeasurementRequest], + ~.Trial]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "add_trial_measurement" not in self._stubs: + self._stubs["add_trial_measurement"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement", + request_serializer=vizier_service.AddTrialMeasurementRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["add_trial_measurement"] + + @property + def complete_trial( + self, + ) -> Callable[[vizier_service.CompleteTrialRequest], study.Trial]: + r"""Return a callable for the complete trial method over gRPC. + + Marks a Trial as complete. + + Returns: + Callable[[~.CompleteTrialRequest], + ~.Trial]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "complete_trial" not in self._stubs: + self._stubs["complete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial", + request_serializer=vizier_service.CompleteTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["complete_trial"] + + @property + def delete_trial( + self, + ) -> Callable[[vizier_service.DeleteTrialRequest], empty.Empty]: + r"""Return a callable for the delete trial method over gRPC. + + Deletes a Trial. + + Returns: + Callable[[~.DeleteTrialRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "delete_trial" not in self._stubs: + self._stubs["delete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial", + request_serializer=vizier_service.DeleteTrialRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["delete_trial"] + + @property + def check_trial_early_stopping_state( + self, + ) -> Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], operations.Operation + ]: + r"""Return a callable for the check trial early stopping + state method over gRPC. + + Checks whether a Trial should stop or not. Returns a + long-running operation. When the operation is successful, it + will contain a + ``CheckTrialEarlyStoppingStateResponse``. + + Returns: + Callable[[~.CheckTrialEarlyStoppingStateRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "check_trial_early_stopping_state" not in self._stubs: + self._stubs[ + "check_trial_early_stopping_state" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState", + request_serializer=vizier_service.CheckTrialEarlyStoppingStateRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["check_trial_early_stopping_state"] + + @property + def stop_trial(self) -> Callable[[vizier_service.StopTrialRequest], study.Trial]: + r"""Return a callable for the stop trial method over gRPC. + + Stops a Trial. + + Returns: + Callable[[~.StopTrialRequest], + ~.Trial]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "stop_trial" not in self._stubs: + self._stubs["stop_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/StopTrial", + request_serializer=vizier_service.StopTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["stop_trial"] + + @property + def list_optimal_trials( + self, + ) -> Callable[ + [vizier_service.ListOptimalTrialsRequest], + vizier_service.ListOptimalTrialsResponse, + ]: + r"""Return a callable for the list optimal trials method over gRPC. + + Lists the pareto-optimal Trials for multi-objective Study or the + optimal Trials for single-objective Study. The definition of + pareto-optimal can be checked in wiki page. + https://en.wikipedia.org/wiki/Pareto_efficiency + + Returns: + Callable[[~.ListOptimalTrialsRequest], + ~.ListOptimalTrialsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_optimal_trials" not in self._stubs: + self._stubs["list_optimal_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials", + request_serializer=vizier_service.ListOptimalTrialsRequest.serialize, + response_deserializer=vizier_service.ListOptimalTrialsResponse.deserialize, + ) + return self._stubs["list_optimal_trials"] + + +__all__ = ("VizierServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..82e28342a4 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/vizier_service/transports/grpc_asyncio.py @@ -0,0 +1,702 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.cloud.aiplatform_v1beta1.types import vizier_service +from google.longrunning import operations_pb2 as operations # type: ignore +from google.protobuf import empty_pb2 as empty # type: ignore + +from .base import VizierServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import VizierServiceGrpcTransport + + +class VizierServiceGrpcAsyncIOTransport(VizierServiceTransport): + """gRPC AsyncIO backend transport for VizierService. + + Cloud AI Platform Vizier API. + Vizier service is a GCP service to solve blackbox optimization + problems, such as tuning machine learning hyperparameters and + searching over deep learning architectures. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. 
These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ aio.Channel: A gRPC AsyncIO channel object.
+ """
+ scopes = scopes or cls.AUTH_SCOPES
+ return grpc_helpers_async.create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ **kwargs,
+ )
+
+ def __init__(
+ self,
+ *,
+ host: str = "aiplatform.googleapis.com",
+ credentials: credentials.Credentials = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ channel: aio.Channel = None,
+ api_mtls_endpoint: str = None,
+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+ ssl_channel_credentials: grpc.ChannelCredentials = None,
+ client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+ quota_project_id=None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]): The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ channel (Optional[aio.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. 
+ + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_study( + self, + ) -> Callable[[vizier_service.CreateStudyRequest], Awaitable[gca_study.Study]]: + r"""Return a callable for the create study method over gRPC. + + Creates a Study. A resource name will be generated + after creation of the Study. + + Returns: + Callable[[~.CreateStudyRequest], + Awaitable[~.Study]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_study" not in self._stubs: + self._stubs["create_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateStudy", + request_serializer=vizier_service.CreateStudyRequest.serialize, + response_deserializer=gca_study.Study.deserialize, + ) + return self._stubs["create_study"] + + @property + def get_study( + self, + ) -> Callable[[vizier_service.GetStudyRequest], Awaitable[study.Study]]: + r"""Return a callable for the get study method over gRPC. + + Gets a Study by name. + + Returns: + Callable[[~.GetStudyRequest], + Awaitable[~.Study]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_study" not in self._stubs: + self._stubs["get_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetStudy", + request_serializer=vizier_service.GetStudyRequest.serialize, + response_deserializer=study.Study.deserialize, + ) + return self._stubs["get_study"] + + @property + def list_studies( + self, + ) -> Callable[ + [vizier_service.ListStudiesRequest], + Awaitable[vizier_service.ListStudiesResponse], + ]: + r"""Return a callable for the list studies method over gRPC. + + Lists all the studies in a region for an associated + project. + + Returns: + Callable[[~.ListStudiesRequest], + Awaitable[~.ListStudiesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_studies" not in self._stubs: + self._stubs["list_studies"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListStudies", + request_serializer=vizier_service.ListStudiesRequest.serialize, + response_deserializer=vizier_service.ListStudiesResponse.deserialize, + ) + return self._stubs["list_studies"] + + @property + def delete_study( + self, + ) -> Callable[[vizier_service.DeleteStudyRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the delete study method over gRPC. + + Deletes a Study. + + Returns: + Callable[[~.DeleteStudyRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_study" not in self._stubs: + self._stubs["delete_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteStudy", + request_serializer=vizier_service.DeleteStudyRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["delete_study"] + + @property + def lookup_study( + self, + ) -> Callable[[vizier_service.LookupStudyRequest], Awaitable[study.Study]]: + r"""Return a callable for the lookup study method over gRPC. + + Looks a study up using the user-defined display_name field + instead of the fully qualified resource name. + + Returns: + Callable[[~.LookupStudyRequest], + Awaitable[~.Study]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "lookup_study" not in self._stubs: + self._stubs["lookup_study"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/LookupStudy", + request_serializer=vizier_service.LookupStudyRequest.serialize, + response_deserializer=study.Study.deserialize, + ) + return self._stubs["lookup_study"] + + @property + def suggest_trials( + self, + ) -> Callable[ + [vizier_service.SuggestTrialsRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the suggest trials method over gRPC. + + Adds one or more Trials to a Study, with parameter values + suggested by AI Platform Vizier. Returns a long-running + operation associated with the generation of Trial suggestions. + When this long-running operation succeeds, it will contain a + ``SuggestTrialsResponse``. + + Returns: + Callable[[~.SuggestTrialsRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "suggest_trials" not in self._stubs: + self._stubs["suggest_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/SuggestTrials", + request_serializer=vizier_service.SuggestTrialsRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["suggest_trials"] + + @property + def create_trial( + self, + ) -> Callable[[vizier_service.CreateTrialRequest], Awaitable[study.Trial]]: + r"""Return a callable for the create trial method over gRPC. + + Adds a user provided Trial to a Study. + + Returns: + Callable[[~.CreateTrialRequest], + Awaitable[~.Trial]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "create_trial" not in self._stubs: + self._stubs["create_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CreateTrial", + request_serializer=vizier_service.CreateTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["create_trial"] + + @property + def get_trial( + self, + ) -> Callable[[vizier_service.GetTrialRequest], Awaitable[study.Trial]]: + r"""Return a callable for the get trial method over gRPC. + + Gets a Trial. + + Returns: + Callable[[~.GetTrialRequest], + Awaitable[~.Trial]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_trial" not in self._stubs: + self._stubs["get_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/GetTrial", + request_serializer=vizier_service.GetTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["get_trial"] + + @property + def list_trials( + self, + ) -> Callable[ + [vizier_service.ListTrialsRequest], Awaitable[vizier_service.ListTrialsResponse] + ]: + r"""Return a callable for the list trials method over gRPC. + + Lists the Trials associated with a Study. + + Returns: + Callable[[~.ListTrialsRequest], + Awaitable[~.ListTrialsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_trials" not in self._stubs: + self._stubs["list_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListTrials", + request_serializer=vizier_service.ListTrialsRequest.serialize, + response_deserializer=vizier_service.ListTrialsResponse.deserialize, + ) + return self._stubs["list_trials"] + + @property + def add_trial_measurement( + self, + ) -> Callable[[vizier_service.AddTrialMeasurementRequest], Awaitable[study.Trial]]: + r"""Return a callable for the add trial measurement method over gRPC. + + Adds a measurement of the objective metrics to a + Trial. This measurement is assumed to have been taken + before the Trial is complete. + + Returns: + Callable[[~.AddTrialMeasurementRequest], + Awaitable[~.Trial]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "add_trial_measurement" not in self._stubs: + self._stubs["add_trial_measurement"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/AddTrialMeasurement", + request_serializer=vizier_service.AddTrialMeasurementRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["add_trial_measurement"] + + @property + def complete_trial( + self, + ) -> Callable[[vizier_service.CompleteTrialRequest], Awaitable[study.Trial]]: + r"""Return a callable for the complete trial method over gRPC. + + Marks a Trial as complete. 
+ + Returns: + Callable[[~.CompleteTrialRequest], + Awaitable[~.Trial]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "complete_trial" not in self._stubs: + self._stubs["complete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CompleteTrial", + request_serializer=vizier_service.CompleteTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["complete_trial"] + + @property + def delete_trial( + self, + ) -> Callable[[vizier_service.DeleteTrialRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the delete trial method over gRPC. + + Deletes a Trial. + + Returns: + Callable[[~.DeleteTrialRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_trial" not in self._stubs: + self._stubs["delete_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/DeleteTrial", + request_serializer=vizier_service.DeleteTrialRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["delete_trial"] + + @property + def check_trial_early_stopping_state( + self, + ) -> Callable[ + [vizier_service.CheckTrialEarlyStoppingStateRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the check trial early stopping + state method over gRPC. + + Checks whether a Trial should stop or not. Returns a + long-running operation. When the operation is successful, it + will contain a + ``CheckTrialEarlyStoppingStateResponse``. + + Returns: + Callable[[~.CheckTrialEarlyStoppingStateRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "check_trial_early_stopping_state" not in self._stubs: + self._stubs[ + "check_trial_early_stopping_state" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/CheckTrialEarlyStoppingState", + request_serializer=vizier_service.CheckTrialEarlyStoppingStateRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["check_trial_early_stopping_state"] + + @property + def stop_trial( + self, + ) -> Callable[[vizier_service.StopTrialRequest], Awaitable[study.Trial]]: + r"""Return a callable for the stop trial method over gRPC. + + Stops a Trial. + + Returns: + Callable[[~.StopTrialRequest], + Awaitable[~.Trial]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "stop_trial" not in self._stubs: + self._stubs["stop_trial"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/StopTrial", + request_serializer=vizier_service.StopTrialRequest.serialize, + response_deserializer=study.Trial.deserialize, + ) + return self._stubs["stop_trial"] + + @property + def list_optimal_trials( + self, + ) -> Callable[ + [vizier_service.ListOptimalTrialsRequest], + Awaitable[vizier_service.ListOptimalTrialsResponse], + ]: + r"""Return a callable for the list optimal trials method over gRPC. + + Lists the pareto-optimal Trials for multi-objective Study or the + optimal Trials for single-objective Study. The definition of + pareto-optimal can be checked in wiki page. + https://en.wikipedia.org/wiki/Pareto_efficiency + + Returns: + Callable[[~.ListOptimalTrialsRequest], + Awaitable[~.ListOptimalTrialsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_optimal_trials" not in self._stubs: + self._stubs["list_optimal_trials"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.VizierService/ListOptimalTrials", + request_serializer=vizier_service.ListOptimalTrialsRequest.serialize, + response_deserializer=vizier_service.ListOptimalTrialsResponse.deserialize, + ) + return self._stubs["list_optimal_trials"] + + +__all__ = ("VizierServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index ca848c7c54..2d2368df8c 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,379 +15,429 @@ # limitations under the License. 
# -from .user_action_reference import UserActionReference from .annotation import Annotation from .annotation_spec import AnnotationSpec -from .completion_stats import CompletionStats -from .encryption_spec import EncryptionSpec -from .explanation_metadata import ExplanationMetadata -from .explanation import ( - Explanation, - ModelExplanation, - Attribution, - ExplanationSpec, - ExplanationParameters, - SampledShapleyAttribution, - IntegratedGradientsAttribution, - XraiAttribution, - SmoothGradConfig, - FeatureNoiseSigma, - ExplanationSpecOverride, - ExplanationMetadataOverride, -) -from .io import ( - GcsSource, - GcsDestination, - BigQuerySource, - BigQueryDestination, - ContainerRegistryDestination, -) -from .machine_resources import ( - MachineSpec, - DedicatedResources, - AutomaticResources, - BatchDedicatedResources, - ResourcesConsumed, - DiskSpec, -) -from .manual_batch_tuning_parameters import ManualBatchTuningParameters from .batch_prediction_job import BatchPredictionJob -from .env_var import EnvVar +from .completion_stats import CompletionStats from .custom_job import ( + ContainerSpec, CustomJob, CustomJobSpec, - WorkerPoolSpec, - ContainerSpec, PythonPackageSpec, Scheduling, + WorkerPoolSpec, ) from .data_item import DataItem -from .specialist_pool import SpecialistPool from .data_labeling_job import ( - DataLabelingJob, ActiveLearningConfig, + DataLabelingJob, SampleConfig, TrainingConfig, ) from .dataset import ( Dataset, - ImportDataConfig, ExportDataConfig, -) -from .operation import ( - GenericOperationMetadata, - DeleteOperationMetadata, -) -from .deployed_model_ref import DeployedModelRef -from .model import ( - Model, - PredictSchemata, - ModelContainerSpec, - Port, -) -from .training_pipeline import ( - TrainingPipeline, - InputDataConfig, - FractionSplit, - FilterSplit, - PredefinedSplit, - TimestampSplit, + ImportDataConfig, ) from .dataset_service import ( - CreateDatasetRequest, CreateDatasetOperationMetadata, - GetDatasetRequest, - UpdateDatasetRequest, - ListDatasetsRequest, - ListDatasetsResponse, + CreateDatasetRequest, DeleteDatasetRequest, - ImportDataRequest, - ImportDataResponse, - ImportDataOperationMetadata, + ExportDataOperationMetadata, ExportDataRequest, ExportDataResponse, - ExportDataOperationMetadata, - ListDataItemsRequest, - ListDataItemsResponse, GetAnnotationSpecRequest, + GetDatasetRequest, + ImportDataOperationMetadata, + ImportDataRequest, + ImportDataResponse, ListAnnotationsRequest, ListAnnotationsResponse, + ListDataItemsRequest, + ListDataItemsResponse, + ListDatasetsRequest, + ListDatasetsResponse, + UpdateDatasetRequest, ) +from .deployed_model_ref import DeployedModelRef +from .encryption_spec import EncryptionSpec from .endpoint import ( - Endpoint, DeployedModel, + Endpoint, ) from .endpoint_service import ( - CreateEndpointRequest, CreateEndpointOperationMetadata, - GetEndpointRequest, - ListEndpointsRequest, - ListEndpointsResponse, - UpdateEndpointRequest, + CreateEndpointRequest, DeleteEndpointRequest, + DeployModelOperationMetadata, DeployModelRequest, DeployModelResponse, - DeployModelOperationMetadata, + GetEndpointRequest, + ListEndpointsRequest, + ListEndpointsResponse, + UndeployModelOperationMetadata, UndeployModelRequest, UndeployModelResponse, - UndeployModelOperationMetadata, + UpdateEndpointRequest, ) -from .study import ( - Trial, - StudySpec, - Measurement, +from .env_var import EnvVar +from .explanation import ( + Attribution, + Explanation, + ExplanationMetadataOverride, + ExplanationParameters, + 
ExplanationSpec, + ExplanationSpecOverride, + FeatureNoiseSigma, + IntegratedGradientsAttribution, + ModelExplanation, + SampledShapleyAttribution, + SmoothGradConfig, + XraiAttribution, ) +from .explanation_metadata import ExplanationMetadata from .hyperparameter_tuning_job import HyperparameterTuningJob +from .io import ( + BigQueryDestination, + BigQuerySource, + ContainerRegistryDestination, + GcsDestination, + GcsSource, +) from .job_service import ( + CancelBatchPredictionJobRequest, + CancelCustomJobRequest, + CancelDataLabelingJobRequest, + CancelHyperparameterTuningJobRequest, + CreateBatchPredictionJobRequest, CreateCustomJobRequest, + CreateDataLabelingJobRequest, + CreateHyperparameterTuningJobRequest, + DeleteBatchPredictionJobRequest, + DeleteCustomJobRequest, + DeleteDataLabelingJobRequest, + DeleteHyperparameterTuningJobRequest, + GetBatchPredictionJobRequest, GetCustomJobRequest, + GetDataLabelingJobRequest, + GetHyperparameterTuningJobRequest, + ListBatchPredictionJobsRequest, + ListBatchPredictionJobsResponse, ListCustomJobsRequest, ListCustomJobsResponse, - DeleteCustomJobRequest, - CancelCustomJobRequest, - CreateDataLabelingJobRequest, - GetDataLabelingJobRequest, ListDataLabelingJobsRequest, ListDataLabelingJobsResponse, - DeleteDataLabelingJobRequest, - CancelDataLabelingJobRequest, - CreateHyperparameterTuningJobRequest, - GetHyperparameterTuningJobRequest, ListHyperparameterTuningJobsRequest, ListHyperparameterTuningJobsResponse, - DeleteHyperparameterTuningJobRequest, - CancelHyperparameterTuningJobRequest, - CreateBatchPredictionJobRequest, - GetBatchPredictionJobRequest, - ListBatchPredictionJobsRequest, - ListBatchPredictionJobsResponse, - DeleteBatchPredictionJobRequest, - CancelBatchPredictionJobRequest, ) +from .machine_resources import ( + AutomaticResources, + AutoscalingMetricSpec, + BatchDedicatedResources, + DedicatedResources, + DiskSpec, + MachineSpec, + ResourcesConsumed, +) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters from .migratable_resource import MigratableResource from .migration_service import ( - SearchMigratableResourcesRequest, - SearchMigratableResourcesResponse, + BatchMigrateResourcesOperationMetadata, BatchMigrateResourcesRequest, - MigrateResourceRequest, BatchMigrateResourcesResponse, + MigrateResourceRequest, MigrateResourceResponse, - BatchMigrateResourcesOperationMetadata, + SearchMigratableResourcesRequest, + SearchMigratableResourcesResponse, +) +from .model import ( + Model, + ModelContainerSpec, + Port, + PredictSchemata, ) from .model_evaluation import ModelEvaluation from .model_evaluation_slice import ModelEvaluationSlice from .model_service import ( - UploadModelRequest, - UploadModelOperationMetadata, - UploadModelResponse, - GetModelRequest, - ListModelsRequest, - ListModelsResponse, - UpdateModelRequest, DeleteModelRequest, - ExportModelRequest, ExportModelOperationMetadata, + ExportModelRequest, ExportModelResponse, GetModelEvaluationRequest, - ListModelEvaluationsRequest, - ListModelEvaluationsResponse, GetModelEvaluationSliceRequest, + GetModelRequest, ListModelEvaluationSlicesRequest, ListModelEvaluationSlicesResponse, + ListModelEvaluationsRequest, + ListModelEvaluationsResponse, + ListModelsRequest, + ListModelsResponse, + UpdateModelRequest, + UploadModelOperationMetadata, + UploadModelRequest, + UploadModelResponse, +) +from .operation import ( + DeleteOperationMetadata, + GenericOperationMetadata, ) from .pipeline_service import ( + CancelTrainingPipelineRequest, 
CreateTrainingPipelineRequest, + DeleteTrainingPipelineRequest, GetTrainingPipelineRequest, ListTrainingPipelinesRequest, ListTrainingPipelinesResponse, - DeleteTrainingPipelineRequest, - CancelTrainingPipelineRequest, ) from .prediction_service import ( - PredictRequest, - PredictResponse, ExplainRequest, ExplainResponse, + PredictRequest, + PredictResponse, ) +from .specialist_pool import SpecialistPool from .specialist_pool_service import ( - CreateSpecialistPoolRequest, CreateSpecialistPoolOperationMetadata, + CreateSpecialistPoolRequest, + DeleteSpecialistPoolRequest, GetSpecialistPoolRequest, ListSpecialistPoolsRequest, ListSpecialistPoolsResponse, - DeleteSpecialistPoolRequest, - UpdateSpecialistPoolRequest, UpdateSpecialistPoolOperationMetadata, + UpdateSpecialistPoolRequest, +) +from .study import ( + Measurement, + Study, + StudySpec, + Trial, +) +from .training_pipeline import ( + FilterSplit, + FractionSplit, + InputDataConfig, + PredefinedSplit, + TimestampSplit, + TrainingPipeline, +) +from .user_action_reference import UserActionReference +from .vizier_service import ( + AddTrialMeasurementRequest, + CheckTrialEarlyStoppingStateMetatdata, + CheckTrialEarlyStoppingStateRequest, + CheckTrialEarlyStoppingStateResponse, + CompleteTrialRequest, + CreateStudyRequest, + CreateTrialRequest, + DeleteStudyRequest, + DeleteTrialRequest, + GetStudyRequest, + GetTrialRequest, + ListOptimalTrialsRequest, + ListOptimalTrialsResponse, + ListStudiesRequest, + ListStudiesResponse, + ListTrialsRequest, + ListTrialsResponse, + LookupStudyRequest, + StopTrialRequest, + SuggestTrialsMetadata, + SuggestTrialsRequest, + SuggestTrialsResponse, ) __all__ = ( "AcceleratorType", - "UserActionReference", "Annotation", "AnnotationSpec", - "CompletionStats", - "EncryptionSpec", - "ExplanationMetadata", - "Explanation", - "ModelExplanation", - "Attribution", - "ExplanationSpec", - "ExplanationParameters", - "SampledShapleyAttribution", - "IntegratedGradientsAttribution", - "XraiAttribution", - "SmoothGradConfig", - "FeatureNoiseSigma", - "ExplanationSpecOverride", - "ExplanationMetadataOverride", - "GcsSource", - "GcsDestination", - "BigQuerySource", - "BigQueryDestination", - "ContainerRegistryDestination", - "JobState", - "MachineSpec", - "DedicatedResources", - "AutomaticResources", - "BatchDedicatedResources", - "ResourcesConsumed", - "DiskSpec", - "ManualBatchTuningParameters", "BatchPredictionJob", - "EnvVar", + "CompletionStats", + "ContainerSpec", "CustomJob", "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", "PythonPackageSpec", "Scheduling", + "WorkerPoolSpec", "DataItem", - "SpecialistPool", - "DataLabelingJob", "ActiveLearningConfig", + "DataLabelingJob", "SampleConfig", "TrainingConfig", "Dataset", - "ImportDataConfig", "ExportDataConfig", - "GenericOperationMetadata", - "DeleteOperationMetadata", - "DeployedModelRef", - "Model", - "PredictSchemata", - "ModelContainerSpec", - "Port", - "PipelineState", - "TrainingPipeline", - "InputDataConfig", - "FractionSplit", - "FilterSplit", - "PredefinedSplit", - "TimestampSplit", - "CreateDatasetRequest", + "ImportDataConfig", "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", + "CreateDatasetRequest", "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", + "ExportDataOperationMetadata", "ExportDataRequest", "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", 
"GetAnnotationSpecRequest", + "GetDatasetRequest", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", "ListAnnotationsRequest", "ListAnnotationsResponse", - "Endpoint", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "UpdateDatasetRequest", + "DeployedModelRef", + "EncryptionSpec", "DeployedModel", - "CreateEndpointRequest", + "Endpoint", "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", + "CreateEndpointRequest", "DeleteEndpointRequest", + "DeployModelOperationMetadata", "DeployModelRequest", "DeployModelResponse", - "DeployModelOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UndeployModelOperationMetadata", "UndeployModelRequest", "UndeployModelResponse", - "UndeployModelOperationMetadata", - "Trial", - "StudySpec", - "Measurement", + "UpdateEndpointRequest", + "EnvVar", + "Attribution", + "Explanation", + "ExplanationMetadataOverride", + "ExplanationParameters", + "ExplanationSpec", + "ExplanationSpecOverride", + "FeatureNoiseSigma", + "IntegratedGradientsAttribution", + "ModelExplanation", + "SampledShapleyAttribution", + "SmoothGradConfig", + "XraiAttribution", + "ExplanationMetadata", "HyperparameterTuningJob", + "BigQueryDestination", + "BigQuerySource", + "ContainerRegistryDestination", + "GcsDestination", + "GcsSource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteHyperparameterTuningJobRequest", + "GetBatchPredictionJobRequest", "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetHyperparameterTuningJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", "ListCustomJobsRequest", "ListCustomJobsResponse", - "DeleteCustomJobRequest", - "CancelCustomJobRequest", - "CreateDataLabelingJobRequest", - "GetDataLabelingJobRequest", "ListDataLabelingJobsRequest", "ListDataLabelingJobsResponse", - "DeleteDataLabelingJobRequest", - "CancelDataLabelingJobRequest", - "CreateHyperparameterTuningJobRequest", - "GetHyperparameterTuningJobRequest", "ListHyperparameterTuningJobsRequest", "ListHyperparameterTuningJobsResponse", - "DeleteHyperparameterTuningJobRequest", - "CancelHyperparameterTuningJobRequest", - "CreateBatchPredictionJobRequest", - "GetBatchPredictionJobRequest", - "ListBatchPredictionJobsRequest", - "ListBatchPredictionJobsResponse", - "DeleteBatchPredictionJobRequest", - "CancelBatchPredictionJobRequest", + "JobState", + "AutomaticResources", + "AutoscalingMetricSpec", + "BatchDedicatedResources", + "DedicatedResources", + "DiskSpec", + "MachineSpec", + "ResourcesConsumed", + "ManualBatchTuningParameters", "MigratableResource", - "SearchMigratableResourcesRequest", - "SearchMigratableResourcesResponse", + "BatchMigrateResourcesOperationMetadata", "BatchMigrateResourcesRequest", - "MigrateResourceRequest", "BatchMigrateResourcesResponse", + "MigrateResourceRequest", "MigrateResourceResponse", - "BatchMigrateResourcesOperationMetadata", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "Model", + "ModelContainerSpec", + "Port", + "PredictSchemata", 
"ModelEvaluation", "ModelEvaluationSlice", - "UploadModelRequest", - "UploadModelOperationMetadata", - "UploadModelResponse", - "GetModelRequest", - "ListModelsRequest", - "ListModelsResponse", - "UpdateModelRequest", "DeleteModelRequest", - "ExportModelRequest", "ExportModelOperationMetadata", + "ExportModelRequest", "ExportModelResponse", "GetModelEvaluationRequest", - "ListModelEvaluationsRequest", - "ListModelEvaluationsResponse", "GetModelEvaluationSliceRequest", + "GetModelRequest", "ListModelEvaluationSlicesRequest", "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "DeleteOperationMetadata", + "GenericOperationMetadata", + "CancelTrainingPipelineRequest", "CreateTrainingPipelineRequest", + "DeleteTrainingPipelineRequest", "GetTrainingPipelineRequest", "ListTrainingPipelinesRequest", "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", - "PredictRequest", - "PredictResponse", + "PipelineState", "ExplainRequest", "ExplainResponse", - "CreateSpecialistPoolRequest", + "PredictRequest", + "PredictResponse", + "SpecialistPool", "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "DeleteSpecialistPoolRequest", "GetSpecialistPoolRequest", "ListSpecialistPoolsRequest", "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "Measurement", + "Study", + "StudySpec", + "Trial", + "FilterSplit", + "FractionSplit", + "InputDataConfig", + "PredefinedSplit", + "TimestampSplit", + "TrainingPipeline", + "UserActionReference", + "AddTrialMeasurementRequest", + "CheckTrialEarlyStoppingStateMetatdata", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CompleteTrialRequest", + "CreateStudyRequest", + "CreateTrialRequest", + "DeleteStudyRequest", + "DeleteTrialRequest", + "GetStudyRequest", + "GetTrialRequest", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", + "ListStudiesRequest", + "ListStudiesResponse", + "ListTrialsRequest", + "ListTrialsResponse", + "LookupStudyRequest", + "StopTrialRequest", + "SuggestTrialsMetadata", + "SuggestTrialsRequest", + "SuggestTrialsResponse", ) diff --git a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py index 337b0eeaf5..8c6968952c 100644 --- a/google/cloud/aiplatform_v1beta1/types/accelerator_type.py +++ b/google/cloud/aiplatform_v1beta1/types/accelerator_type.py @@ -31,8 +31,6 @@ class AcceleratorType(proto.Enum): NVIDIA_TESLA_V100 = 3 NVIDIA_TESLA_P4 = 4 NVIDIA_TESLA_T4 = 5 - TPU_V2 = 6 - TPU_V3 = 7 __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index 74bb5eac98..a42ef0da82 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -56,7 +56,7 @@ class Annotation(proto.Message): Output only. Timestamp when this Annotation was last updated. etag (str): - Optional. Used to perform a consistent read- + Optional. Used to perform consistent read- odify-write updates. If not set, a blind "overwrite" update happens. 
annotation_source (google.cloud.aiplatform_v1beta1.types.UserActionReference): @@ -78,7 +78,7 @@ class Annotation(proto.Message): - "aiplatform.googleapis.com/annotation_set_name": optional, name of the UI's annotation set this Annotation - belongs to. If not set the Annotation is not visible in + belongs to. If not set, the Annotation is not visible in the UI. - "aiplatform.googleapis.com/payload_schema": output only, diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index 9d35539a5b..e921e25971 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -46,7 +46,7 @@ class AnnotationSpec(proto.Message): Output only. Timestamp when AnnotationSpec was last updated. etag (str): - Optional. Used to perform a consistent read- + Optional. Used to perform consistent read- odify-write updates. If not set, a blind "overwrite" update happens. """ diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index a51de2f9a2..9c79349b9e 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -117,8 +117,10 @@ class BatchPredictionJob(proto.Message): - ``csv``: Generating explanations for CSV format is not supported. - If this field is set to true, the + If this field is set to true, either the ``Model.explanation_spec`` + or + ``explanation_spec`` must be populated. explanation_spec (google.cloud.aiplatform_v1beta1.types.ExplanationSpec): Explanation configuration for this BatchPredictionJob. Can @@ -288,7 +290,6 @@ class OutputConfig(proto.Message): Required. The format in which AI Platform gives the predictions, must be one of the [Model's][google.cloud.aiplatform.v1beta1.BatchPredictionJob.model] - ``supported_output_storage_formats``. """ diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index a1674327ef..1d148b7777 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -122,7 +122,9 @@ class CustomJobSpec(proto.Message): Attributes: worker_pool_specs (Sequence[google.cloud.aiplatform_v1beta1.types.WorkerPoolSpec]): Required. The spec of the worker pools - including machine type and Docker image. + including machine type and Docker image. All + worker pools except the first one are optional + and can be skipped by providing an empty value. scheduling (google.cloud.aiplatform_v1beta1.types.Scheduling): Scheduling options for a CustomJob. service_account (str): @@ -256,12 +258,13 @@ class PythonPackageSpec(proto.Message): Attributes: executor_image_uri (str): - Required. The URI of a container image in the - Container Registry that will run the provided - python package. AI Platform provides wide range - of executor images with pre-installed packages - to meet users' various use cases. Only one of - the provided images can be set here. + Required. The URI of a container image in Artifact Registry + that will run the provided Python package. AI Platform + provides a wide range of executor images with pre-installed + packages to meet users' various use cases. See the list of + `pre-built containers for + training `__. + You must use an image from this list. package_uris (Sequence[str]): Required. 
The Google Cloud Storage location of the Python package files which are the diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index eff2516bda..a12776f06c 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -63,7 +63,7 @@ class DataItem(proto.Message): schema's][google.cloud.aiplatform.v1beta1.Dataset.metadata_schema_uri] dataItemSchemaUri field. etag (str): - Optional. Used to perform a consistent read- + Optional. Used to perform consistent read- odify-write updates. If not set, a blind "overwrite" update happens. """ diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index dc06549ac4..d750f53e66 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -133,7 +133,7 @@ class DataLabelingJob(proto.Message): are associated with the EncryptionSpec of the Dataset they are exported to. active_learning_config (google.cloud.aiplatform_v1beta1.types.ActiveLearningConfig): - Parameters that configure active learning + Parameters that configure the active learning pipeline. Active learning will label the data incrementally via several iterations. For every iteration, it will select a batch of data based @@ -182,8 +182,8 @@ class DataLabelingJob(proto.Message): class ActiveLearningConfig(proto.Message): - r"""Parameters that configure active learning pipeline. Active - learning will label the data incrementally by several + r"""Parameters that configure the active learning pipeline. + Active learning will label the data incrementally by several iterations. For every iteration, it will select a batch of data based on the sampling strategy. @@ -233,7 +233,7 @@ class SampleConfig(proto.Message): in each following batch (except the first batch). sample_strategy (google.cloud.aiplatform_v1beta1.types.SampleConfig.SampleStrategy): - Field to chose sampling strategy. Sampling + Field to choose sampling strategy. Sampling strategy will decide which data should be selected for human labeling in every batch. """ diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 2ca1244527..9fa17fcb3a 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -130,7 +130,7 @@ class ImportDataConfig(proto.Message): be picked randomly. Two DataItems are considered identical if their content bytes are identical (e.g. image bytes or pdf bytes). These labels will be overridden by Annotation - labels specified inside index file refenced by + labels specified inside index file referenced by ``import_schema_uri``, e.g. jsonl file. import_schema_uri (str): diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 91c64f4b5d..1ab94b8c89 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -372,7 +372,6 @@ class GetAnnotationSpecRequest(proto.Message): Attributes: name (str): Required. The name of the AnnotationSpec resource. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. 
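Looping back to the active-learning wording fixes a few hunks above, the fields described there combine roughly as follows. This is a sketch under assumptions: the ``max_data_item_count`` budget field and the two batch-percentage fields are taken from the generated ``ActiveLearningConfig`` and ``SampleConfig`` messages and should be verified there; only ``sample_strategy`` is spelled out in the diff itself.

from google.cloud.aiplatform_v1beta1.types import data_labeling_job

# Label a first batch, then smaller follow-up batches, choosing items with an
# uncertainty-based sampling strategy (values are placeholders).
active_learning_config = data_labeling_job.ActiveLearningConfig(
    max_data_item_count=1000,  # assumed human-labeling budget field
    sample_config=data_labeling_job.SampleConfig(
        initial_batch_sample_percentage=10,
        following_batch_sample_percentage=5,
        sample_strategy=data_labeling_job.SampleConfig.SampleStrategy.UNCERTAINTY,
    ),
)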
@@ -391,7 +390,6 @@ class ListAnnotationsRequest(proto.Message): parent (str): Required. The resource name of the DataItem to list Annotations from. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` filter (str): The standard list filter. diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index bdbcb6ff21..40ede068f3 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -128,9 +128,9 @@ class DeployedModel(proto.Message): id (str): Output only. The ID of the DeployedModel. model (str): - Required. The name of the Model this is the - deployment of. Note that the Model may be in a - different location than the DeployedModel's + Required. The name of the Model that this is + the deployment of. Note that the Model may be in + a different location than the DeployedModel's Endpoint. display_name (str): The display name of the DeployedModel. If not provided upon @@ -151,10 +151,11 @@ class DeployedModel(proto.Message): ``explanation_spec`` is not populated, the value of the same field of ``Model.explanation_spec`` - is inherited. The corresponding + is inherited. If the corresponding ``Model.explanation_spec`` - must be populated, otherwise explanation for this Model is - not allowed. + is not populated, all fields of the + ``explanation_spec`` + will be used for the explanation configuration. service_account (str): The service account that the DeployedModel's container runs as. Specify the email address of the service account. If diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index 72d1063334..fe7442ab2a 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -152,7 +152,7 @@ class ListEndpointsResponse(proto.Message): endpoints (Sequence[google.cloud.aiplatform_v1beta1.types.Endpoint]): List of Endpoints in the requested page. next_page_token (str): - A token to retrieve next page of results. Pass to + A token to retrieve the next page of results. Pass to ``ListEndpointsRequest.page_token`` to obtain that page. """ diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index dcbc32a4f5..69947e9b9e 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -52,8 +52,8 @@ class ExplanationMetadata(proto.Message): Required. Map from output names to output metadata. For AI Platform provided Tensorflow images, keys - can be any string user defines. - + can be any user defined string that consists of + any UTF-8 characters. For custom images, keys are the name of the output field in the prediction to be explained. @@ -368,7 +368,7 @@ class OutputMetadata(proto.Message): values. The shape of the value must be an n-dimensional array of - strings. The number of dimentions must match that of the + strings. The number of dimensions must match that of the outputs to be explained. 
The ``Attribution.output_display_name`` is populated by locating in the mapping with diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index b032cc2bae..3a177dcf9b 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -82,9 +82,9 @@ class BigQueryDestination(proto.Message): Required. BigQuery URI to a project or table, up to 2000 characters long. - When only project is specified, Dataset and Table is - created. When full table reference is specified, Dataset - must exist and table must not exist. + When only the project is specified, the Dataset and Table is + created. When the full table reference is specified, the + Dataset must exist and table must not exist. Accepted forms: @@ -96,7 +96,7 @@ class BigQueryDestination(proto.Message): class ContainerRegistryDestination(proto.Message): - r"""The Container Regsitry location for the container image. + r"""The Container Registry location for the container image. Attributes: output_uri (str): diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index ead3e7c765..514ca12f7a 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -151,7 +151,7 @@ class ListCustomJobsResponse(proto.Message): custom_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.CustomJob]): List of CustomJobs in the requested page. next_page_token (str): - A token to retrieve next page of results. Pass to + A token to retrieve the next page of results. Pass to ``ListCustomJobsRequest.page_token`` to obtain that page. """ @@ -219,7 +219,6 @@ class GetDataLabelingJobRequest(proto.Message): Attributes: name (str): Required. The name of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` """ @@ -311,7 +310,6 @@ class DeleteDataLabelingJobRequest(proto.Message): name (str): Required. The name of the DataLabelingJob to be deleted. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` """ @@ -325,7 +323,6 @@ class CancelDataLabelingJobRequest(proto.Message): Attributes: name (str): Required. The name of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` """ @@ -363,7 +360,6 @@ class GetHyperparameterTuningJobRequest(proto.Message): name (str): Required. The name of the HyperparameterTuningJob resource. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` """ @@ -430,7 +426,7 @@ class ListHyperparameterTuningJobsResponse(proto.Message): ``HyperparameterTuningJob.trials`` of the jobs will be not be returned. next_page_token (str): - A token to retrieve next page of results. Pass to + A token to retrieve the next page of results. Pass to ``ListHyperparameterTuningJobsRequest.page_token`` to obtain that page. """ @@ -456,7 +452,6 @@ class DeleteHyperparameterTuningJobRequest(proto.Message): name (str): Required. The name of the HyperparameterTuningJob resource to be deleted. Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` """ @@ -471,7 +466,6 @@ class CancelHyperparameterTuningJobRequest(proto.Message): name (str): Required. The name of the HyperparameterTuningJob to cancel. 
Format: - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` """ @@ -506,7 +500,6 @@ class GetBatchPredictionJobRequest(proto.Message): name (str): Required. The name of the BatchPredictionJob resource. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` """ @@ -531,6 +524,8 @@ class ListBatchPredictionJobsRequest(proto.Message): - ``state`` supports = and !=. + - ``model_display_name`` supports = and != + Some examples of using the filter are: - ``state="JOB_STATE_SUCCEEDED" AND display_name="my_job"`` @@ -572,7 +567,7 @@ class ListBatchPredictionJobsResponse(proto.Message): List of BatchPredictionJobs in the requested page. next_page_token (str): - A token to retrieve next page of results. Pass to + A token to retrieve the next page of results. Pass to ``ListBatchPredictionJobsRequest.page_token`` to obtain that page. """ @@ -596,7 +591,6 @@ class DeleteBatchPredictionJobRequest(proto.Message): name (str): Required. The name of the BatchPredictionJob resource to be deleted. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` """ @@ -611,7 +605,6 @@ class CancelBatchPredictionJobRequest(proto.Message): name (str): Required. The name of the BatchPredictionJob to cancel. Format: - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` """ diff --git a/google/cloud/aiplatform_v1beta1/types/job_state.py b/google/cloud/aiplatform_v1beta1/types/job_state.py index f86e179b1b..b77947cc9a 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_state.py +++ b/google/cloud/aiplatform_v1beta1/types/job_state.py @@ -34,6 +34,7 @@ class JobState(proto.Enum): JOB_STATE_CANCELLING = 6 JOB_STATE_CANCELLED = 7 JOB_STATE_PAUSED = 8 + JOB_STATE_EXPIRED = 9 __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 50dc4b3eef..c791354c58 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -32,6 +32,7 @@ "BatchDedicatedResources", "ResourcesConsumed", "DiskSpec", + "AutoscalingMetricSpec", }, ) @@ -95,15 +96,43 @@ class DedicatedResources(proto.Message): max_replica_count (int): Immutable. The maximum number of replicas this DeployedModel may be deployed on when the traffic against it increases. If - requested value is too large, the deployment will error, but - if deployment succeeds then the ability to scale the model - to that many replicas is guaranteed (barring service + the requested value is too large, the deployment will error, + but if deployment succeeds then the ability to scale the + model to that many replicas is guaranteed (barring service outages). If traffic against the DeployedModel increases beyond what its replicas at maximum may handle, a portion of the traffic will be dropped. If this value is not provided, will use ``min_replica_count`` as the default value. + autoscaling_metric_specs (Sequence[google.cloud.aiplatform_v1beta1.types.AutoscalingMetricSpec]): + Immutable. The metric specifications that overrides a + resource utilization metric (CPU utilization, accelerator's + duty cycle, and so on) target value (default to 60 if not + set). At most one entry is allowed per metric. 
+ + If + ``machine_spec.accelerator_count`` + is above 0, the autoscaling will be based on both CPU + utilization and accelerator's duty cycle metrics and scale + up when either metrics exceeds its target value while scale + down if both metrics are under their target value. The + default target value is 60 for both metrics. + + If + ``machine_spec.accelerator_count`` + is 0, the autoscaling will be based on CPU utilization + metric only with default target value 60 if not explicitly + set. + + For example, in the case of Online Prediction, if you want + to override target CPU utilization to 80, you should set + ``autoscaling_metric_specs.metric_name`` + to + ``aiplatform.googleapis.com/prediction/online/cpu/utilization`` + and + ``autoscaling_metric_specs.target`` + to ``80``. """ machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) @@ -112,6 +141,10 @@ class DedicatedResources(proto.Message): max_replica_count = proto.Field(proto.INT32, number=3) + autoscaling_metric_specs = proto.RepeatedField( + proto.MESSAGE, number=4, message="AutoscalingMetricSpec", + ) + class AutomaticResources(proto.Message): r"""A description of resources that to large degree are decided @@ -126,20 +159,20 @@ class AutomaticResources(proto.Message): it may dynamically be deployed onto more replicas up to ``max_replica_count``, and as traffic decreases, some of these extra replicas may - be freed. If requested value is too large, the deployment - will error. + be freed. If the requested value is too large, the + deployment will error. max_replica_count (int): Immutable. The maximum number of replicas this DeployedModel may be deployed on when the - traffic against it increases. If requested value - is too large, the deployment will error, but if - deployment succeeds then the ability to scale - the model to that many replicas is guaranteed - (barring service outages). If traffic against - the DeployedModel increases beyond what its - replicas at maximum may handle, a portion of the - traffic will be dropped. If this value is not - provided, a no upper bound for scaling under + traffic against it increases. If the requested + value is too large, the deployment will error, + but if deployment succeeds then the ability to + scale the model to that many replicas is + guaranteed (barring service outages). If traffic + against the DeployedModel increases beyond what + its replicas at maximum may handle, a portion of + the traffic will be dropped. If this value is + not provided, a no upper bound for scaling under heavy traffic will be assume, though AI Platform may be unable to scale beyond certain replica number. @@ -211,4 +244,30 @@ class DiskSpec(proto.Message): boot_disk_size_gb = proto.Field(proto.INT32, number=2) +class AutoscalingMetricSpec(proto.Message): + r"""The metric specification that defines the target resource + utilization (CPU utilization, accelerator's duty cycle, and so + on) for calculating the desired replica count. + + Attributes: + metric_name (str): + Required. The resource metric name. Supported metrics: + + - For Online Prediction: + - ``aiplatform.googleapis.com/prediction/online/accelerator/duty_cycle`` + - ``aiplatform.googleapis.com/prediction/online/cpu/utilization`` + target (int): + The target resource utilization in percentage + (1% - 100%) for the given metric; once the real + usage deviates from the target by a certain + percentage, the machine replicas change. The + default value is 60 (representing 60%) if not + provided. 
+ """ + + metric_name = proto.Field(proto.STRING, number=1) + + target = proto.Field(proto.INT32, number=2) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py index 144ff94acc..9a695ea349 100644 --- a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -44,10 +44,10 @@ class MigratableResource(proto.Message): Output only. Represents one Dataset in datalabeling.googleapis.com. last_migrate_time (google.protobuf.timestamp_pb2.Timestamp): - Output only. Timestamp when last migrate - attempt on this MigratableResource started. Will - not be set if there's no migrate attempt on this - MigratableResource. + Output only. Timestamp when the last + migration attempt on this MigratableResource + started. Will not be set if there's no migration + attempt on this MigratableResource. last_update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this MigratableResource was last updated. @@ -130,7 +130,6 @@ class DataLabelingAnnotatedDataset(proto.Message): annotated_dataset (str): Full resource name of data labeling AnnotatedDataset. Format: - ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. annotated_dataset_display_name (str): The AnnotatedDataset's display name in diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py index ef37006233..de4c9466f6 100644 --- a/google/cloud/aiplatform_v1beta1/types/migration_service.py +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -71,7 +71,7 @@ class SearchMigratableResourcesRequest(proto.Message): - ``last_migrate_time:*`` will filter migrated resources. - ``NOT last_migrate_time:*`` will filter not yet - migrated resource. + migrated resources. """ parent = proto.Field(proto.STRING, number=1) @@ -247,7 +247,6 @@ class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): annotated_dataset (str): Required. Full resource name of data labeling AnnotatedDataset. Format: - ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. """ @@ -331,12 +330,12 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. partial_results (Sequence[google.cloud.aiplatform_v1beta1.types.BatchMigrateResourcesOperationMetadata.PartialResult]): - Partial results that reflects the latest + Partial results that reflect the latest migration operation progress. """ class PartialResult(proto.Message): - r"""Represents a partial result in batch migration opreation for one + r"""Represents a partial result in batch migration operation for one ``MigrateResourceRequest``. Attributes: diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index c9e1ddd68a..4dcf6baefa 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -58,7 +58,7 @@ class Model(proto.Message): 3.0.2 `Schema Object `__. AutoML Models always have this field populated by AI - Platform, if no additional metadata is needed this field is + Platform, if no additional metadata is needed, this field is set to an empty string. 
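The worked example in the new ``autoscaling_metric_specs`` documentation above (raise the default 60% CPU-utilization target to 80) translates to roughly the following; the machine type and replica counts are placeholders.

from google.cloud.aiplatform_v1beta1.types import machine_resources

dedicated = machine_resources.DedicatedResources(
    machine_spec=machine_resources.MachineSpec(machine_type="n1-standard-4"),  # placeholder
    min_replica_count=1,
    max_replica_count=5,
    # New in this change: per-metric autoscaling targets, at most one entry
    # per metric. Here the CPU-utilization target is overridden from 60 to 80.
    autoscaling_metric_specs=[
        machine_resources.AutoscalingMetricSpec(
            metric_name="aiplatform.googleapis.com/prediction/online/cpu/utilization",
            target=80,
        )
    ],
)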
Note: The URI given on output will be immutable and probably different, including the URI scheme, than the one given on input. The output URI will @@ -205,8 +205,8 @@ class Model(proto.Message): The Model can be used for [requesting explanation][PredictionService.Explain] after being ``deployed`` - iff it is populated. The Model can be used for [batch - explanation][BatchPredictionJob.generate_explanation] iff it + if it is populated. The Model can be used for [batch + explanation][BatchPredictionJob.generate_explanation] if it is populated. All fields of the explanation_spec can be overridden by @@ -217,6 +217,19 @@ class Model(proto.Message): ``explanation_spec`` of ``BatchPredictionJob``. + + If the default explanation specification is not set for this + Model, this Model can still be used for [requesting + explanation][PredictionService.Explain] by setting + ``explanation_spec`` + of + ``DeployModelRequest.deployed_model`` + and for [batch + explanation][BatchPredictionJob.generate_explanation] by + setting + ``explanation_spec`` + of + ``BatchPredictionJob``. etag (str): Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update @@ -244,7 +257,7 @@ class DeploymentResourcesType(proto.Enum): AUTOMATIC_RESOURCES = 2 class ExportFormat(proto.Message): - r"""Represents a supported by the Model export format. + r"""Represents export format supported by the Model. All formats export to Google Cloud Storage. Attributes: @@ -372,8 +385,8 @@ class PredictSchemata(proto.Message): The schema is defined as an OpenAPI 3.0.2 `Schema Object `__. AutoML Models always have this field populated by AI - Platform, if no parameters are supported it is set to an - empty string. Note: The URI given on output will be + Platform, if no parameters are supported, then it is set to + an empty string. Note: The URI given on output will be immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user only has a read access. @@ -424,6 +437,11 @@ class ModelContainerSpec(proto.Message): To learn about the requirements for the Docker image itself, see `Custom container requirements `__. + + You can use the URI to one of AI Platform's `pre-built + container images for + prediction `__ + in this field. command (Sequence[str]): Immutable. Specifies the command that runs when the container starts. This overrides the container's @@ -596,7 +614,7 @@ class ModelContainerSpec(proto.Message): ```AIP_DEPLOYED_MODEL_ID`` environment variable `__.) health_route (str): - Immutable. HTTP path on the container to send health checkss + Immutable. HTTP path on the container to send health checks to. AI Platform intermittently sends GET requests to this path on the container's IP address and port to check that the container is healthy. Read more about `health diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 1091c71148..e0d8e148ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -329,7 +329,6 @@ class GetModelEvaluationRequest(proto.Message): Attributes: name (str): Required. The name of the ModelEvaluation resource. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` """ @@ -403,7 +402,6 @@ class GetModelEvaluationSliceRequest(proto.Message): name (str): Required. The name of the ModelEvaluationSlice resource. 
Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` """ @@ -418,7 +416,6 @@ class ListModelEvaluationSlicesRequest(proto.Message): parent (str): Required. The resource name of the ModelEvaluation to list the ModelEvaluationSlices from. Format: - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` filter (str): The standard list filter. diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index 32eec8489c..b06361dfa9 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -64,7 +64,6 @@ class GetTrainingPipelineRequest(proto.Message): Attributes: name (str): Required. The name of the TrainingPipeline resource. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` """ @@ -128,7 +127,7 @@ class ListTrainingPipelinesResponse(proto.Message): List of TrainingPipelines in the requested page. next_page_token (str): - A token to retrieve next page of results. Pass to + A token to retrieve the next page of results. Pass to ``ListTrainingPipelinesRequest.page_token`` to obtain that page. """ @@ -152,7 +151,6 @@ class DeleteTrainingPipelineRequest(proto.Message): name (str): Required. The name of the TrainingPipeline resource to be deleted. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` """ @@ -167,7 +165,6 @@ class CancelTrainingPipelineRequest(proto.Message): name (str): Required. The name of the TrainingPipeline to cancel. Format: - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` """ diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 9d0241080b..3ed6593bd6 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -80,7 +80,6 @@ class GetSpecialistPoolRequest(proto.Message): name (str): Required. The name of the SpecialistPool resource. The form is - ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. """ @@ -189,7 +188,6 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialist_pool (str): Output only. The name of the SpecialistPool to which the specialists are being added. Format: - ``projects/{project_id}/locations/{location_id}/specialistPools/{specialist_pool}`` generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 4f8b972746..092d3a3e2d 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -24,16 +24,63 @@ __protobuf__ = proto.module( package="google.cloud.aiplatform.v1beta1", - manifest={"Trial", "StudySpec", "Measurement",}, + manifest={"Study", "Trial", "StudySpec", "Measurement",}, ) +class Study(proto.Message): + r"""A message representing a Study. + + Attributes: + name (str): + Output only. The name of a study. The study's globally + unique identifier. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + display_name (str): + Required. Describes the Study, default value + is empty string. 
+ study_spec (google.cloud.aiplatform_v1beta1.types.StudySpec): + Required. Configuration of the Study. + state (google.cloud.aiplatform_v1beta1.types.Study.State): + Output only. The detailed state of a Study. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time at which the study was + created. + inactive_reason (str): + Output only. A human readable reason why the + Study is inactive. This should be empty if a + study is ACTIVE or COMPLETED. + """ + + class State(proto.Enum): + r"""Describes the Study state.""" + STATE_UNSPECIFIED = 0 + ACTIVE = 1 + INACTIVE = 2 + COMPLETED = 3 + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + study_spec = proto.Field(proto.MESSAGE, number=3, message="StudySpec",) + + state = proto.Field(proto.ENUM, number=4, enum=State,) + + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + inactive_reason = proto.Field(proto.STRING, number=6) + + class Trial(proto.Message): r"""A message representing a Trial. A Trial contains a unique set of Parameters that has been or will be evaluated, along with the objective metrics got by running the Trial. Attributes: + name (str): + Output only. Resource name of the Trial + assigned by the service. id (str): Output only. The identifier of the Trial assigned by the service. @@ -84,6 +131,8 @@ class Parameter(proto.Message): value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) + name = proto.Field(proto.STRING, number=1) + id = proto.Field(proto.STRING, number=2) state = proto.Field(proto.ENUM, number=3, enum=State,) @@ -103,6 +152,15 @@ class StudySpec(proto.Message): r"""Represents specification of a Study. Attributes: + decay_curve_stopping_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.DecayCurveAutomatedStoppingSpec): + The automated early stopping spec using decay + curve rule. + median_automated_stopping_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.MedianAutomatedStoppingSpec): + The automated early stopping spec using + median rule. + convex_stop_config (google.cloud.aiplatform_v1beta1.types.StudySpec.ConvexStopConfig): + The automated early stopping using convex + stopping rule. metrics (Sequence[google.cloud.aiplatform_v1beta1.types.StudySpec.MetricSpec]): Required. Metric specs for the Study. parameters (Sequence[google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec]): @@ -391,6 +449,113 @@ class CategoricalValueCondition(proto.Message): message="StudySpec.ParameterSpec.ConditionalParameterSpec", ) + class DecayCurveAutomatedStoppingSpec(proto.Message): + r"""The decay curve automated stopping rule builds a Gaussian + Process Regressor to predict the final objective value of a + Trial based on the already completed Trials and the intermediate + measurements of the current Trial. Early stopping is requested + for the current Trial if there is very low probability to exceed + the optimal value found so far. + + Attributes: + use_elapsed_duration (bool): + True if + ``Measurement.elapsed_duration`` + is used as the x-axis of each Trials Decay Curve. Otherwise, + ``Measurement.step_count`` + will be used as the x-axis. + """ + + use_elapsed_duration = proto.Field(proto.BOOL, number=1) + + class MedianAutomatedStoppingSpec(proto.Message): + r"""The median automated stopping rule stops a pending Trial if the + Trial's best objective_value is strictly below the median + 'performance' of all completed Trials reported up to the Trial's + last measurement. 
Currently, 'performance' refers to the running + average of the objective values reported by the Trial in each + measurement. + + Attributes: + use_elapsed_duration (bool): + True if median automated stopping rule applies on + ``Measurement.elapsed_duration``. + It means that elapsed_duration field of latest measurement + of current Trial is used to compute median objective value + for each completed Trials. + """ + + use_elapsed_duration = proto.Field(proto.BOOL, number=1) + + class ConvexStopConfig(proto.Message): + r"""Configuration for ConvexStopPolicy. + + Attributes: + max_num_steps (int): + Steps used in predicting the final objective for early + stopped trials. In general, it's set to be the same as the + defined steps in training / tuning. When use_steps is false, + this field is set to the maximum elapsed seconds. + min_num_steps (int): + Minimum number of steps for a trial to complete. Trials + which do not have a measurement with num_steps > + min_num_steps won't be considered for early stopping. It's + ok to set it to 0, and a trial can be early stopped at any + stage. By default, min_num_steps is set to be one-tenth of + the max_num_steps. When use_steps is false, this field is + set to the minimum elapsed seconds. + autoregressive_order (int): + The number of Trial measurements used in + autoregressive model for value prediction. A + trial won't be considered early stopping if has + fewer measurement points. + learning_rate_parameter_name (str): + The hyper-parameter name used in the tuning job that stands + for learning rate. Leave it blank if learning rate is not in + a parameter in tuning. The learning_rate is used to estimate + the objective value of the ongoing trial. + use_seconds (bool): + This bool determines whether or not the rule is applied + based on elapsed_secs or steps. If use_seconds==false, the + early stopping decision is made according to the predicted + objective values according to the target steps. If + use_seconds==true, elapsed_secs is used instead of steps. + Also, in this case, the parameters max_num_steps and + min_num_steps are overloaded to contain max_elapsed_seconds + and min_elapsed_seconds. 
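# A rough sketch of how the automated stopping specs in this hunk compose:
# the decay-curve, median, and convex specs all share the
# "automated_stopping_spec" oneof, so assigning one member clears the others.
# Field values are placeholders.
from google.cloud.aiplatform_v1beta1.types import study as gca_study

spec = gca_study.StudySpec(
    median_automated_stopping_spec=gca_study.StudySpec.MedianAutomatedStoppingSpec(
        use_elapsed_duration=False,
    ),
)
# Assigning a different member of the oneof replaces the median rule.
spec.decay_curve_stopping_spec = gca_study.StudySpec.DecayCurveAutomatedStoppingSpec(
    use_elapsed_duration=True,
)
assert (
    gca_study.StudySpec.pb(spec).WhichOneof("automated_stopping_spec")
    == "decay_curve_stopping_spec"
)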
+ """ + + max_num_steps = proto.Field(proto.INT64, number=1) + + min_num_steps = proto.Field(proto.INT64, number=2) + + autoregressive_order = proto.Field(proto.INT64, number=3) + + learning_rate_parameter_name = proto.Field(proto.STRING, number=4) + + use_seconds = proto.Field(proto.BOOL, number=5) + + decay_curve_stopping_spec = proto.Field( + proto.MESSAGE, + number=4, + oneof="automated_stopping_spec", + message=DecayCurveAutomatedStoppingSpec, + ) + + median_automated_stopping_spec = proto.Field( + proto.MESSAGE, + number=5, + oneof="automated_stopping_spec", + message=MedianAutomatedStoppingSpec, + ) + + convex_stop_config = proto.Field( + proto.MESSAGE, + number=8, + oneof="automated_stopping_spec", + message=ConvexStopConfig, + ) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 6175a12e96..3c03b0f47d 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -218,17 +218,15 @@ class InputDataConfig(proto.Message): - AIP_DATA_FORMAT = "jsonl" for non-tabular data, "csv" for tabular data - - AIP_TRAINING_DATA_URI = - "gcs_destination/dataset---/training-*.${AIP_DATA_FORMAT}" + - AIP_TRAINING_DATA_URI = + "gcs_destination/dataset---/training-*.${AIP_DATA_FORMAT}" - AIP_VALIDATION_DATA_URI = - - "gcs_destination/dataset---/validation-*.${AIP_DATA_FORMAT}" + "gcs_destination/dataset---/validation-*.${AIP_DATA_FORMAT}" - AIP_TEST_DATA_URI = - - "gcs_destination/dataset---/test-*.${AIP_DATA_FORMAT}". + "gcs_destination/dataset---/test-*.${AIP_DATA_FORMAT}". bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): Only applicable to custom training with tabular Dataset with BigQuery source. @@ -243,13 +241,12 @@ class InputDataConfig(proto.Message): ``validation`` and ``test``. - AIP_DATA_FORMAT = "bigquery". - - AIP_TRAINING_DATA_URI = - "bigquery_destination.dataset\_\ **\ .training" + - AIP_TRAINING_DATA_URI = + "bigquery_destination.dataset\_\ **\ .training" - AIP_VALIDATION_DATA_URI = - - "bigquery_destination.dataset\_\ **\ .validation" + "bigquery_destination.dataset\_\ **\ .validation" - AIP_TEST_DATA_URI = "bigquery_destination.dataset\_\ **\ .test". diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index 710e4a6d16..25180ae567 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -37,7 +37,6 @@ class UserActionReference(proto.Message): data_labeling_job (str): For API calls that start a LabelingJob. Resource name of the LabelingJob. Format: - 'projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}' method (str): The method name of the API call. For example, diff --git a/google/cloud/aiplatform_v1beta1/types/vizier_service.py b/google/cloud/aiplatform_v1beta1/types/vizier_service.py new file mode 100644 index 0000000000..2b837c476e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/vizier_service.py @@ -0,0 +1,479 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import operation +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "GetStudyRequest", + "CreateStudyRequest", + "ListStudiesRequest", + "ListStudiesResponse", + "DeleteStudyRequest", + "LookupStudyRequest", + "SuggestTrialsRequest", + "SuggestTrialsResponse", + "SuggestTrialsMetadata", + "CreateTrialRequest", + "GetTrialRequest", + "ListTrialsRequest", + "ListTrialsResponse", + "AddTrialMeasurementRequest", + "CompleteTrialRequest", + "DeleteTrialRequest", + "CheckTrialEarlyStoppingStateRequest", + "CheckTrialEarlyStoppingStateResponse", + "CheckTrialEarlyStoppingStateMetatdata", + "StopTrialRequest", + "ListOptimalTrialsRequest", + "ListOptimalTrialsResponse", + }, +) + + +class GetStudyRequest(proto.Message): + r"""Request message for + ``VizierService.GetStudy``. + + Attributes: + name (str): + Required. The name of the Study resource. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateStudyRequest(proto.Message): + r"""Request message for + ``VizierService.CreateStudy``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + CustomJob in. Format: + ``projects/{project}/locations/{location}`` + study (google.cloud.aiplatform_v1beta1.types.Study): + Required. The Study configuration used to + create the Study. + """ + + parent = proto.Field(proto.STRING, number=1) + + study = proto.Field(proto.MESSAGE, number=2, message=gca_study.Study,) + + +class ListStudiesRequest(proto.Message): + r"""Request message for + ``VizierService.ListStudies``. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + Study from. Format: + ``projects/{project}/locations/{location}`` + page_token (str): + Optional. A page token to request the next + page of results. If unspecified, there are no + subsequent pages. + page_size (int): + Optional. The maximum number of studies to + return per "page" of results. If unspecified, + service will pick an appropriate default. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_token = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + +class ListStudiesResponse(proto.Message): + r"""Response message for + ``VizierService.ListStudies``. + + Attributes: + studies (Sequence[google.cloud.aiplatform_v1beta1.types.Study]): + The studies associated with the project. + next_page_token (str): + Passes this token as the ``page_token`` field of the request + for a subsequent call. If this field is omitted, there are + no subsequent pages. 
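# To make the paging fields above concrete: a hedged sketch of walking every
# page of ListStudies by hand. It assumes a generated VizierService client
# object (called ``client`` here), which is outside this hunk.
from google.cloud.aiplatform_v1beta1.types import vizier_service


def iter_studies(client, parent):
    """Yield every Study under ``parent``, following next_page_token."""
    page_token = ""
    while True:
        response = client.list_studies(
            vizier_service.ListStudiesRequest(parent=parent, page_token=page_token)
        )
        for study in response.studies:
            yield study
        page_token = response.next_page_token
        if not page_token:  # an empty token means there are no further pages
            break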
+ """ + + @property + def raw_page(self): + return self + + studies = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Study,) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteStudyRequest(proto.Message): + r"""Request message for + ``VizierService.DeleteStudy``. + + Attributes: + name (str): + Required. The name of the Study resource to be deleted. + Format: + ``projects/{project}/locations/{location}/studies/{study}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class LookupStudyRequest(proto.Message): + r"""Request message for + ``VizierService.LookupStudy``. + + Attributes: + parent (str): + Required. The resource name of the Location to get the Study + from. Format: ``projects/{project}/locations/{location}`` + display_name (str): + Required. The user-defined display name of + the Study + """ + + parent = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + +class SuggestTrialsRequest(proto.Message): + r"""Request message for + ``VizierService.SuggestTrials``. + + Attributes: + parent (str): + Required. The project and location that the Study belongs + to. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + suggestion_count (int): + Required. The number of suggestions + requested. + client_id (str): + Required. The identifier of the client that is requesting + the suggestion. + + If multiple SuggestTrialsRequests have the same + ``client_id``, the service will return the identical + suggested Trial if the Trial is pending, and provide a new + Trial if the last suggested Trial was completed. + """ + + parent = proto.Field(proto.STRING, number=1) + + suggestion_count = proto.Field(proto.INT32, number=2) + + client_id = proto.Field(proto.STRING, number=3) + + +class SuggestTrialsResponse(proto.Message): + r"""Response message for + ``VizierService.SuggestTrials``. + + Attributes: + trials (Sequence[google.cloud.aiplatform_v1beta1.types.Trial]): + A list of Trials. + study_state (google.cloud.aiplatform_v1beta1.types.Study.State): + The state of the Study. + start_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which the operation was started. + end_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which operation processing + completed. + """ + + trials = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Trial,) + + study_state = proto.Field(proto.ENUM, number=2, enum=gca_study.Study.State,) + + start_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + +class SuggestTrialsMetadata(proto.Message): + r"""Details of operations that perform Trials suggestion. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for suggesting Trials. + client_id (str): + The identifier of the client that is requesting the + suggestion. + + If multiple SuggestTrialsRequests have the same + ``client_id``, the service will return the identical + suggested Trial if the Trial is pending, and provide a new + Trial if the last suggested Trial was completed. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + client_id = proto.Field(proto.STRING, number=2) + + +class CreateTrialRequest(proto.Message): + r"""Request message for + ``VizierService.CreateTrial``. + + Attributes: + parent (str): + Required. 
The resource name of the Study to create the Trial + in. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + trial (google.cloud.aiplatform_v1beta1.types.Trial): + Required. The Trial to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + trial = proto.Field(proto.MESSAGE, number=2, message=gca_study.Trial,) + + +class GetTrialRequest(proto.Message): + r"""Request message for + ``VizierService.GetTrial``. + + Attributes: + name (str): + Required. The name of the Trial resource. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListTrialsRequest(proto.Message): + r"""Request message for + ``VizierService.ListTrials``. + + Attributes: + parent (str): + Required. The resource name of the Study to list the Trial + from. Format: + ``projects/{project}/locations/{location}/studies/{study}`` + page_token (str): + Optional. A page token to request the next + page of results. If unspecified, there are no + subsequent pages. + page_size (int): + Optional. The number of Trials to retrieve + per "page" of results. If unspecified, the + service will pick an appropriate default. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_token = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + +class ListTrialsResponse(proto.Message): + r"""Response message for + ``VizierService.ListTrials``. + + Attributes: + trials (Sequence[google.cloud.aiplatform_v1beta1.types.Trial]): + The Trials associated with the Study. + next_page_token (str): + Pass this token as the ``page_token`` field of the request + for a subsequent call. If this field is omitted, there are + no subsequent pages. + """ + + @property + def raw_page(self): + return self + + trials = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_study.Trial,) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class AddTrialMeasurementRequest(proto.Message): + r"""Request message for + ``VizierService.AddTrialMeasurement``. + + Attributes: + trial_name (str): + Required. The name of the trial to add measurement. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + measurement (google.cloud.aiplatform_v1beta1.types.Measurement): + Required. The measurement to be added to a + Trial. + """ + + trial_name = proto.Field(proto.STRING, number=1) + + measurement = proto.Field(proto.MESSAGE, number=3, message=gca_study.Measurement,) + + +class CompleteTrialRequest(proto.Message): + r"""Request message for + ``VizierService.CompleteTrial``. + + Attributes: + name (str): + Required. The Trial's name. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + final_measurement (google.cloud.aiplatform_v1beta1.types.Measurement): + Optional. If provided, it will be used as the completed + Trial's final_measurement; Otherwise, the service will + auto-select a previously reported measurement as the + final-measurement + trial_infeasible (bool): + Optional. True if the Trial cannot be run with the given + Parameter, and final_measurement will be ignored. + infeasible_reason (str): + Optional. A human readable reason why the trial was + infeasible. This should only be provided if + ``trial_infeasible`` is true. 
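# A small, self-contained illustration of these request messages; the resource
# names and reason string below are placeholders, not values from this patch.
from google.cloud.aiplatform_v1beta1.types import vizier_service

suggest_request = vizier_service.SuggestTrialsRequest(
    parent="projects/my-project/locations/us-central1/studies/my-study",
    suggestion_count=2,
    client_id="worker-0",  # reusing the same id returns a still-pending Trial
)

complete_request = vizier_service.CompleteTrialRequest(
    name="projects/my-project/locations/us-central1/studies/my-study/trials/1",
    trial_infeasible=True,
    infeasible_reason="training job ran out of memory",
)

assert suggest_request.suggestion_count == 2
assert complete_request.trial_infeasible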
+ """ + + name = proto.Field(proto.STRING, number=1) + + final_measurement = proto.Field( + proto.MESSAGE, number=2, message=gca_study.Measurement, + ) + + trial_infeasible = proto.Field(proto.BOOL, number=3) + + infeasible_reason = proto.Field(proto.STRING, number=4) + + +class DeleteTrialRequest(proto.Message): + r"""Request message for + ``VizierService.DeleteTrial``. + + Attributes: + name (str): + Required. The Trial's name. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CheckTrialEarlyStoppingStateRequest(proto.Message): + r"""Request message for + ``VizierService.CheckTrialEarlyStoppingState``. + + Attributes: + trial_name (str): + Required. The Trial's name. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + """ + + trial_name = proto.Field(proto.STRING, number=1) + + +class CheckTrialEarlyStoppingStateResponse(proto.Message): + r"""Response message for + ``VizierService.CheckTrialEarlyStoppingState``. + + Attributes: + should_stop (bool): + True if the Trial should stop. + """ + + should_stop = proto.Field(proto.BOOL, number=1) + + +class CheckTrialEarlyStoppingStateMetatdata(proto.Message): + r"""This message will be placed in the metadata field of a + google.longrunning.Operation associated with a + CheckTrialEarlyStoppingState request. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): + Operation metadata for suggesting Trials. + study (str): + The name of the Study that the Trial belongs + to. + trial (str): + The Trial name. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + study = proto.Field(proto.STRING, number=2) + + trial = proto.Field(proto.STRING, number=3) + + +class StopTrialRequest(proto.Message): + r"""Request message for + ``VizierService.StopTrial``. + + Attributes: + name (str): + Required. The Trial's name. Format: + ``projects/{project}/locations/{location}/studies/{study}/trials/{trial}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListOptimalTrialsRequest(proto.Message): + r"""Request message for + ``VizierService.ListOptimalTrials``. + + Attributes: + parent (str): + Required. The name of the Study that the + optimal Trial belongs to. + """ + + parent = proto.Field(proto.STRING, number=1) + + +class ListOptimalTrialsResponse(proto.Message): + r"""Response message for + ``VizierService.ListOptimalTrials``. + + Attributes: + optimal_trials (Sequence[google.cloud.aiplatform_v1beta1.types.Trial]): + The pareto-optimal Trials for multiple objective Study or + the optimal trial for single objective Study. The definition + of pareto-optimal can be checked in wiki page. 
+ https://en.wikipedia.org/wiki/Pareto_efficiency + """ + + optimal_trials = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_study.Trial, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/noxfile.py b/noxfile.py index 6fd80b539a..35270f664f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -41,6 +41,9 @@ "docs", ] +# Error if a python version is missing +nox.options.error_on_missing_interpreters = True + @nox.session(python=DEFAULT_PYTHON_VERSION) def lint(session): @@ -93,6 +96,7 @@ def default(session): session.run( "py.test", "--quiet", + f"--junitxml=unit_{session.python}_sponge_log.xml", "--cov=google/cloud", "--cov=tests/unit", "--cov-append", @@ -122,6 +126,9 @@ def system(session): # Sanity check: Only run tests if the environment variable is set. if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): session.skip("Credentials must be set via environment variable") + # Install pyopenssl for mTLS testing. + if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": + session.install("pyopenssl") system_test_exists = os.path.exists(system_test_path) system_test_folder_exists = os.path.exists(system_test_folder_path) @@ -141,9 +148,21 @@ def system(session): # Run py.test against the system tests. if system_test_exists: - session.run("py.test", "--quiet", system_test_path, *session.posargs) + session.run( + "py.test", + "--quiet", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_path, + *session.posargs, + ) if system_test_folder_exists: - session.run("py.test", "--quiet", system_test_folder_path, *session.posargs) + session.run( + "py.test", + "--quiet", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_folder_path, + *session.posargs, + ) @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/renovate.json b/renovate.json index 4fa949311b..f08bc22c9a 100644 --- a/renovate.json +++ b/renovate.json @@ -1,5 +1,6 @@ { "extends": [ "config:base", ":preserveSemverRanges" - ] + ], + "ignorePaths": [".pre-commit-config.yaml"] } diff --git a/tests/unit/gapic/aiplatform_v1/__init__.py b/tests/unit/gapic/aiplatform_v1/__init__.py index 8b13789179..42ffdf2bc4 100644 --- a/tests/unit/gapic/aiplatform_v1/__init__.py +++ b/tests/unit/gapic/aiplatform_v1/__init__.py @@ -1 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
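# Stepping back to the noxfile.py hunk above: a minimal session written with
# the same conventions (fail when an interpreter is missing, emit a JUnit XML
# "sponge" log, and install pyopenssl only when client-certificate testing is
# requested). The session name and pinned Python version are illustrative.
import os

import nox

nox.options.error_on_missing_interpreters = True


@nox.session(python="3.8")
def smoke(session):
    session.install("pytest")
    if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true":
        session.install("pyopenssl")  # only needed for mTLS test runs
    session.run(
        "pytest",
        f"--junitxml=smoke_{session.python}_sponge_log.xml",
        *session.posargs,
    )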
+# diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py index d03570e876..1597014605 100644 --- a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -101,21 +101,25 @@ def test__get_default_mtls_endpoint(): ) -def test_dataset_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], +) +def test_dataset_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = DatasetServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], ) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -125,9 +129,11 @@ def test_dataset_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -485,6 +491,22 @@ def test_create_dataset_from_dict(): test_create_dataset(request_type=dict) +def test_create_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + client.create_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + @pytest.mark.asyncio async def test_create_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest @@ -697,6 +719,22 @@ def test_get_dataset_from_dict(): test_get_dataset(request_type=dict) +def test_get_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
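# The per-RPC *_empty_call tests added in this file all follow one shape:
# patch the transport method, invoke the client method with no arguments, and
# check that the default request type was sent. Below is a hedged sketch of
# the same check written once as a parametrized test; the helper itself is not
# part of this patch, and the method/request pairs mirror the dataset-service
# tests in this file.
from unittest import mock

import pytest
from google.auth import credentials
from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceClient
from google.cloud.aiplatform_v1.types import dataset_service


@pytest.mark.parametrize(
    "method_name,request_type",
    [
        ("create_dataset", dataset_service.CreateDatasetRequest),
        ("get_dataset", dataset_service.GetDatasetRequest),
        ("list_datasets", dataset_service.ListDatasetsRequest),
    ],
)
def test_dataset_service_empty_calls(method_name, request_type):
    client = DatasetServiceClient(
        credentials=credentials.AnonymousCredentials(), transport="grpc",
    )
    # Patch the transport-level callable, call the client method with no
    # arguments, and verify the default request message was constructed.
    with mock.patch.object(
        type(getattr(client.transport, method_name)), "__call__"
    ) as call:
        getattr(client, method_name)()
        call.assert_called()
        _, args, _ = call.mock_calls[0]
        assert args[0] == request_type()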
+ with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + client.get_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + @pytest.mark.asyncio async def test_get_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest @@ -906,6 +944,22 @@ def test_update_dataset_from_dict(): test_update_dataset(request_type=dict) +def test_update_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + client.update_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + @pytest.mark.asyncio async def test_update_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest @@ -1124,6 +1178,22 @@ def test_list_datasets_from_dict(): test_list_datasets(request_type=dict) +def test_list_datasets_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + client.list_datasets() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + @pytest.mark.asyncio async def test_list_datasets_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest @@ -1436,6 +1506,22 @@ def test_delete_dataset_from_dict(): test_delete_dataset(request_type=dict) +def test_delete_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + client.delete_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.DeleteDatasetRequest() + + @pytest.mark.asyncio async def test_delete_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest @@ -1622,6 +1708,22 @@ def test_import_data_from_dict(): test_import_data(request_type=dict) +def test_import_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.import_data), "__call__") as call: + client.import_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ImportDataRequest() + + @pytest.mark.asyncio async def test_import_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest @@ -1834,6 +1936,22 @@ def test_export_data_from_dict(): test_export_data(request_type=dict) +def test_export_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + client.export_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ExportDataRequest() + + @pytest.mark.asyncio async def test_export_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest @@ -2063,6 +2181,22 @@ def test_list_data_items_from_dict(): test_list_data_items(request_type=dict) +def test_list_data_items_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + client.list_data_items() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDataItemsRequest() + + @pytest.mark.asyncio async def test_list_data_items_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest @@ -2410,6 +2544,24 @@ def test_get_annotation_spec_from_dict(): test_get_annotation_spec(request_type=dict) +def test_get_annotation_spec_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + client.get_annotation_spec() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + @pytest.mark.asyncio async def test_get_annotation_spec_async( transport: str = "grpc_asyncio", @@ -2620,6 +2772,22 @@ def test_list_annotations_from_dict(): test_list_annotations(request_type=dict) +def test_list_annotations_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + client.list_annotations() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListAnnotationsRequest() + + @pytest.mark.asyncio async def test_list_annotations_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py index 227af94bf8..bf351a3978 100644 --- a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -98,21 +98,25 @@ def test__get_default_mtls_endpoint(): ) -def test_endpoint_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], +) +def test_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = EndpointServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], ) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -122,9 +126,11 @@ def test_endpoint_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -492,6 +498,22 @@ def test_create_endpoint_from_dict(): test_create_endpoint(request_type=dict) +def test_create_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + client.create_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + @pytest.mark.asyncio async def test_create_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest @@ -704,6 +726,22 @@ def test_get_endpoint_from_dict(): test_get_endpoint(request_type=dict) +def test_get_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + client.get_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + @pytest.mark.asyncio async def test_get_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest @@ -904,6 +942,22 @@ def test_list_endpoints_from_dict(): test_list_endpoints(request_type=dict) +def test_list_endpoints_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + client.list_endpoints() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + @pytest.mark.asyncio async def test_list_endpoints_async( transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest @@ -1254,6 +1308,22 @@ def test_update_endpoint_from_dict(): test_update_endpoint(request_type=dict) +def test_update_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + client.update_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + @pytest.mark.asyncio async def test_update_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest @@ -1471,6 +1541,22 @@ def test_delete_endpoint_from_dict(): test_delete_endpoint(request_type=dict) +def test_delete_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + client.delete_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeleteEndpointRequest() + + @pytest.mark.asyncio async def test_delete_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest @@ -1657,6 +1743,22 @@ def test_deploy_model_from_dict(): test_deploy_model(request_type=dict) +def test_deploy_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + client.deploy_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeployModelRequest() + + @pytest.mark.asyncio async def test_deploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest @@ -1901,6 +2003,22 @@ def test_undeploy_model_from_dict(): test_undeploy_model(request_type=dict) +def test_undeploy_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + client.undeploy_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + @pytest.mark.asyncio async def test_undeploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py index a471b22658..50d1339247 100644 --- a/tests/unit/gapic/aiplatform_v1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -114,20 +114,26 @@ def test__get_default_mtls_endpoint(): assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -def test_job_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [JobServiceClient, JobServiceAsyncClient,], +) +def test_job_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = JobServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) +@pytest.mark.parametrize( + "client_class", [JobServiceClient, JobServiceAsyncClient,], +) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( @@ -136,9 +142,11 @@ def test_job_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -503,6 +511,24 @@ def test_create_custom_job_from_dict(): test_create_custom_job(request_type=dict) +def test_create_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + client.create_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + @pytest.mark.asyncio async def test_create_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest @@ -734,6 +760,22 @@ def test_get_custom_job_from_dict(): test_get_custom_job(request_type=dict) +def test_get_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + client.get_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetCustomJobRequest() + + @pytest.mark.asyncio async def test_get_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest @@ -935,6 +977,22 @@ def test_list_custom_jobs_from_dict(): test_list_custom_jobs(request_type=dict) +def test_list_custom_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + client.list_custom_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListCustomJobsRequest() + + @pytest.mark.asyncio async def test_list_custom_jobs_async( transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest @@ -1263,6 +1321,24 @@ def test_delete_custom_job_from_dict(): test_delete_custom_job(request_type=dict) +def test_delete_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + client.delete_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteCustomJobRequest() + + @pytest.mark.asyncio async def test_delete_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest @@ -1461,6 +1537,24 @@ def test_cancel_custom_job_from_dict(): test_cancel_custom_job(request_type=dict) +def test_cancel_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + client.cancel_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + @pytest.mark.asyncio async def test_cancel_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest @@ -1682,6 +1776,24 @@ def test_create_data_labeling_job_from_dict(): test_create_data_labeling_job(request_type=dict) +def test_create_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + client.create_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_create_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -1956,6 +2068,24 @@ def test_get_data_labeling_job_from_dict(): test_get_data_labeling_job(request_type=dict) +def test_get_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + client.get_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_get_data_labeling_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest @@ -2187,6 +2317,24 @@ def test_list_data_labeling_jobs_from_dict(): test_list_data_labeling_jobs(request_type=dict) +def test_list_data_labeling_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + client.list_data_labeling_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListDataLabelingJobsRequest() + + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async( transport: str = "grpc_asyncio", @@ -2560,6 +2708,24 @@ def test_delete_data_labeling_job_from_dict(): test_delete_data_labeling_job(request_type=dict) +def test_delete_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + client.delete_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_delete_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -2759,6 +2925,24 @@ def test_cancel_data_labeling_job_from_dict(): test_cancel_data_labeling_job(request_type=dict) +def test_cancel_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + client.cancel_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_cancel_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -2973,6 +3157,24 @@ def test_create_hyperparameter_tuning_job_from_dict(): test_create_hyperparameter_tuning_job(request_type=dict) +def test_create_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + client.create_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -3241,6 +3443,24 @@ def test_get_hyperparameter_tuning_job_from_dict(): test_get_hyperparameter_tuning_job(request_type=dict) +def test_get_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + client.get_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -3465,6 +3685,24 @@ def test_list_hyperparameter_tuning_jobs_from_dict(): test_list_hyperparameter_tuning_jobs(request_type=dict) +def test_list_hyperparameter_tuning_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + client.list_hyperparameter_tuning_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async( transport: str = "grpc_asyncio", @@ -3855,6 +4093,24 @@ def test_delete_hyperparameter_tuning_job_from_dict(): test_delete_hyperparameter_tuning_job(request_type=dict) +def test_delete_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + client.delete_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -4055,6 +4311,24 @@ def test_cancel_hyperparameter_tuning_job_from_dict(): test_cancel_hyperparameter_tuning_job(request_type=dict) +def test_cancel_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + client.cancel_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -4262,6 +4536,24 @@ def test_create_batch_prediction_job_from_dict(): test_create_batch_prediction_job(request_type=dict) +def test_create_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + client.create_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_create_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -4518,6 +4810,24 @@ def test_get_batch_prediction_job_from_dict(): test_get_batch_prediction_job(request_type=dict) +def test_get_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + client.get_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_get_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -4735,6 +5045,24 @@ def test_list_batch_prediction_jobs_from_dict(): test_list_batch_prediction_jobs(request_type=dict) +def test_list_batch_prediction_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + client.list_batch_prediction_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListBatchPredictionJobsRequest() + + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async( transport: str = "grpc_asyncio", @@ -5112,6 +5440,24 @@ def test_delete_batch_prediction_job_from_dict(): test_delete_batch_prediction_job(request_type=dict) +def test_delete_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + client.delete_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_delete_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -5311,6 +5657,24 @@ def test_cancel_batch_prediction_job_from_dict(): test_cancel_batch_prediction_job(request_type=dict) +def test_cancel_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + client.cancel_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_cancel_batch_prediction_job_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 9e1ca0513a..04bc7c392a 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -92,21 +92,25 @@ def test__get_default_mtls_endpoint(): ) -def test_migration_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], +) +def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = MigrationServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], ) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -116,9 +120,11 @@ def test_migration_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -494,6 +500,24 @@ def test_search_migratable_resources_from_dict(): test_search_migratable_resources(request_type=dict) +def test_search_migratable_resources_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + client.search_migratable_resources() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + @pytest.mark.asyncio async def test_search_migratable_resources_async( transport: str = "grpc_asyncio", @@ -877,6 +901,24 @@ def test_batch_migrate_resources_from_dict(): test_batch_migrate_resources(request_type=dict) +def test_batch_migrate_resources_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + client.batch_migrate_resources() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.BatchMigrateResourcesRequest() + + @pytest.mark.asyncio async def test_batch_migrate_resources_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index f03d9e5d31..15e4bad05d 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -97,20 +97,26 @@ def test__get_default_mtls_endpoint(): assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -def test_model_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [ModelServiceClient, ModelServiceAsyncClient,], +) +def test_model_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = ModelServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) +@pytest.mark.parametrize( + "client_class", [ModelServiceClient, ModelServiceAsyncClient,], +) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( @@ -119,9 +125,11 @@ def test_model_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -473,6 +481,22 @@ def test_upload_model_from_dict(): test_upload_model(request_type=dict) +def test_upload_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + client.upload_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + @pytest.mark.asyncio async def test_upload_model_async( transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest @@ -709,6 +733,22 @@ def test_get_model_from_dict(): test_get_model(request_type=dict) +def test_get_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + client.get_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelRequest() + + @pytest.mark.asyncio async def test_get_model_async( transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest @@ -939,6 +979,22 @@ def test_list_models_from_dict(): test_list_models(request_type=dict) +def test_list_models_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + client.list_models() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelsRequest() + + @pytest.mark.asyncio async def test_list_models_async( transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest @@ -1281,6 +1337,22 @@ def test_update_model_from_dict(): test_update_model(request_type=dict) +def test_update_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_model), "__call__") as call: + client.update_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UpdateModelRequest() + + @pytest.mark.asyncio async def test_update_model_async( transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest @@ -1520,6 +1592,22 @@ def test_delete_model_from_dict(): test_delete_model(request_type=dict) +def test_delete_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + client.delete_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.DeleteModelRequest() + + @pytest.mark.asyncio async def test_delete_model_async( transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest @@ -1706,6 +1794,22 @@ def test_export_model_from_dict(): test_export_model(request_type=dict) +def test_export_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.export_model), "__call__") as call: + client.export_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ExportModelRequest() + + @pytest.mark.asyncio async def test_export_model_async( transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest @@ -1931,6 +2035,24 @@ def test_get_model_evaluation_from_dict(): test_get_model_evaluation(request_type=dict) +def test_get_model_evaluation_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + client.get_model_evaluation() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationRequest() + + @pytest.mark.asyncio async def test_get_model_evaluation_async( transport: str = "grpc_asyncio", @@ -2145,6 +2267,24 @@ def test_list_model_evaluations_from_dict(): test_list_model_evaluations(request_type=dict) +def test_list_model_evaluations_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), "__call__" + ) as call: + client.list_model_evaluations() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelEvaluationsRequest() + + @pytest.mark.asyncio async def test_list_model_evaluations_async( transport: str = "grpc_asyncio", @@ -2525,6 +2665,24 @@ def test_get_model_evaluation_slice_from_dict(): test_get_model_evaluation_slice(request_type=dict) +def test_get_model_evaluation_slice_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + client.get_model_evaluation_slice() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationSliceRequest() + + @pytest.mark.asyncio async def test_get_model_evaluation_slice_async( transport: str = "grpc_asyncio", @@ -2735,6 +2893,24 @@ def test_list_model_evaluation_slices_from_dict(): test_list_model_evaluation_slices(request_type=dict) +def test_list_model_evaluation_slices_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + client.list_model_evaluation_slices() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelEvaluationSlicesRequest() + + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 23619209b0..21e6d0d44f 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -104,21 +104,25 @@ def test__get_default_mtls_endpoint(): ) -def test_pipeline_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], +) +def test_pipeline_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = PipelineServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], ) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -128,9 +132,11 @@ def test_pipeline_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -514,6 +520,24 @@ def test_create_training_pipeline_from_dict(): test_create_training_pipeline(request_type=dict) +def test_create_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + client.create_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_create_training_pipeline_async( transport: str = "grpc_asyncio", @@ -758,6 +782,24 @@ def test_get_training_pipeline_from_dict(): test_get_training_pipeline(request_type=dict) +def test_get_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + client.get_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_get_training_pipeline_async( transport: str = "grpc_asyncio", @@ -975,6 +1017,24 @@ def test_list_training_pipelines_from_dict(): test_list_training_pipelines(request_type=dict) +def test_list_training_pipelines_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + client.list_training_pipelines() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + @pytest.mark.asyncio async def test_list_training_pipelines_async( transport: str = "grpc_asyncio", @@ -1348,6 +1408,24 @@ def test_delete_training_pipeline_from_dict(): test_delete_training_pipeline(request_type=dict) +def test_delete_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + client.delete_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_delete_training_pipeline_async( transport: str = "grpc_asyncio", @@ -1547,6 +1625,24 @@ def test_cancel_training_pipeline_from_dict(): test_cancel_training_pipeline(request_type=dict) +def test_cancel_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + client.cancel_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_cancel_training_pipeline_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py index e2be66e2c7..d5099832f0 100644 --- a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -97,21 +97,25 @@ def test__get_default_mtls_endpoint(): ) -def test_specialist_pool_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], +) +def test_specialist_pool_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = SpecialistPoolServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], ) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -121,9 +125,11 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -506,6 +512,24 @@ def test_create_specialist_pool_from_dict(): test_create_specialist_pool(request_type=dict) +def test_create_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + client.create_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_create_specialist_pool_async( transport: str = "grpc_asyncio", @@ -753,6 +777,24 @@ def test_get_specialist_pool_from_dict(): test_get_specialist_pool(request_type=dict) +def test_get_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + client.get_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_get_specialist_pool_async( transport: str = "grpc_asyncio", @@ -986,6 +1028,24 @@ def test_list_specialist_pools_from_dict(): test_list_specialist_pools(request_type=dict) +def test_list_specialist_pools_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + client.list_specialist_pools() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + @pytest.mark.asyncio async def test_list_specialist_pools_async( transport: str = "grpc_asyncio", @@ -1376,6 +1436,24 @@ def test_delete_specialist_pool_from_dict(): test_delete_specialist_pool(request_type=dict) +def test_delete_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + client.delete_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_delete_specialist_pool_async( transport: str = "grpc_asyncio", @@ -1588,6 +1666,24 @@ def test_update_specialist_pool_from_dict(): test_update_specialist_pool(request_type=dict) +def test_update_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + client.update_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_update_specialist_pool_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1beta1/__init__.py b/tests/unit/gapic/aiplatform_v1beta1/__init__.py index 8b13789179..42ffdf2bc4 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/__init__.py +++ b/tests/unit/gapic/aiplatform_v1beta1/__init__.py @@ -1 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index fe6e04c2ec..6042fa6f42 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -103,21 +103,25 @@ def test__get_default_mtls_endpoint(): ) -def test_dataset_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], +) +def test_dataset_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = DatasetServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,], ) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -127,9 +131,11 @@ def test_dataset_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -487,6 +493,22 @@ def test_create_dataset_from_dict(): test_create_dataset(request_type=dict) +def test_create_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + client.create_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + @pytest.mark.asyncio async def test_create_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest @@ -699,6 +721,22 @@ def test_get_dataset_from_dict(): test_get_dataset(request_type=dict) +def test_get_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + client.get_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + @pytest.mark.asyncio async def test_get_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest @@ -908,6 +946,22 @@ def test_update_dataset_from_dict(): test_update_dataset(request_type=dict) +def test_update_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + client.update_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + @pytest.mark.asyncio async def test_update_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest @@ -1126,6 +1180,22 @@ def test_list_datasets_from_dict(): test_list_datasets(request_type=dict) +def test_list_datasets_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + client.list_datasets() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + @pytest.mark.asyncio async def test_list_datasets_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest @@ -1438,6 +1508,22 @@ def test_delete_dataset_from_dict(): test_delete_dataset(request_type=dict) +def test_delete_dataset_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + client.delete_dataset() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.DeleteDatasetRequest() + + @pytest.mark.asyncio async def test_delete_dataset_async( transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest @@ -1624,6 +1710,22 @@ def test_import_data_from_dict(): test_import_data(request_type=dict) +def test_import_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.import_data), "__call__") as call: + client.import_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ImportDataRequest() + + @pytest.mark.asyncio async def test_import_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest @@ -1836,6 +1938,22 @@ def test_export_data_from_dict(): test_export_data(request_type=dict) +def test_export_data_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + client.export_data() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ExportDataRequest() + + @pytest.mark.asyncio async def test_export_data_async( transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest @@ -2065,6 +2183,22 @@ def test_list_data_items_from_dict(): test_list_data_items(request_type=dict) +def test_list_data_items_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + client.list_data_items() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDataItemsRequest() + + @pytest.mark.asyncio async def test_list_data_items_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest @@ -2412,6 +2546,24 @@ def test_get_annotation_spec_from_dict(): test_get_annotation_spec(request_type=dict) +def test_get_annotation_spec_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + client.get_annotation_spec() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + @pytest.mark.asyncio async def test_get_annotation_spec_async( transport: str = "grpc_asyncio", @@ -2622,6 +2774,22 @@ def test_list_annotations_from_dict(): test_list_annotations(request_type=dict) +def test_list_annotations_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + client.list_annotations() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListAnnotationsRequest() + + @pytest.mark.asyncio async def test_list_annotations_async( transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 237d6d9268..bda98b26a5 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -103,21 +103,25 @@ def test__get_default_mtls_endpoint(): ) -def test_endpoint_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], +) +def test_endpoint_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = EndpointServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,], ) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -127,9 +131,11 @@ def test_endpoint_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -497,6 +503,22 @@ def test_create_endpoint_from_dict(): test_create_endpoint(request_type=dict) +def test_create_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + client.create_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + @pytest.mark.asyncio async def test_create_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest @@ -709,6 +731,22 @@ def test_get_endpoint_from_dict(): test_get_endpoint(request_type=dict) +def test_get_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + client.get_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + @pytest.mark.asyncio async def test_get_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest @@ -909,6 +947,22 @@ def test_list_endpoints_from_dict(): test_list_endpoints(request_type=dict) +def test_list_endpoints_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + client.list_endpoints() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + @pytest.mark.asyncio async def test_list_endpoints_async( transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest @@ -1259,6 +1313,22 @@ def test_update_endpoint_from_dict(): test_update_endpoint(request_type=dict) +def test_update_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + client.update_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + @pytest.mark.asyncio async def test_update_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest @@ -1476,6 +1546,22 @@ def test_delete_endpoint_from_dict(): test_delete_endpoint(request_type=dict) +def test_delete_endpoint_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + client.delete_endpoint() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeleteEndpointRequest() + + @pytest.mark.asyncio async def test_delete_endpoint_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest @@ -1662,6 +1748,22 @@ def test_deploy_model_from_dict(): test_deploy_model(request_type=dict) +def test_deploy_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + client.deploy_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeployModelRequest() + + @pytest.mark.asyncio async def test_deploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest @@ -1906,6 +2008,22 @@ def test_undeploy_model_from_dict(): test_undeploy_model(request_type=dict) +def test_undeploy_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + client.undeploy_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + @pytest.mark.asyncio async def test_undeploy_model_async( transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index 67b1c6830f..e230d9f4b8 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -117,20 +117,26 @@ def test__get_default_mtls_endpoint(): assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -def test_job_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [JobServiceClient, JobServiceAsyncClient,], +) +def test_job_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = JobServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) +@pytest.mark.parametrize( + "client_class", [JobServiceClient, JobServiceAsyncClient,], +) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( @@ -139,9 +145,11 @@ def test_job_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -506,6 +514,24 @@ def test_create_custom_job_from_dict(): test_create_custom_job(request_type=dict) +def test_create_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + client.create_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + @pytest.mark.asyncio async def test_create_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest @@ -737,6 +763,22 @@ def test_get_custom_job_from_dict(): test_get_custom_job(request_type=dict) +def test_get_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + client.get_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetCustomJobRequest() + + @pytest.mark.asyncio async def test_get_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest @@ -938,6 +980,22 @@ def test_list_custom_jobs_from_dict(): test_list_custom_jobs(request_type=dict) +def test_list_custom_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + client.list_custom_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListCustomJobsRequest() + + @pytest.mark.asyncio async def test_list_custom_jobs_async( transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest @@ -1266,6 +1324,24 @@ def test_delete_custom_job_from_dict(): test_delete_custom_job(request_type=dict) +def test_delete_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + client.delete_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteCustomJobRequest() + + @pytest.mark.asyncio async def test_delete_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest @@ -1464,6 +1540,24 @@ def test_cancel_custom_job_from_dict(): test_cancel_custom_job(request_type=dict) +def test_cancel_custom_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + client.cancel_custom_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + @pytest.mark.asyncio async def test_cancel_custom_job_async( transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest @@ -1685,6 +1779,24 @@ def test_create_data_labeling_job_from_dict(): test_create_data_labeling_job(request_type=dict) +def test_create_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + client.create_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_create_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -1959,6 +2071,24 @@ def test_get_data_labeling_job_from_dict(): test_get_data_labeling_job(request_type=dict) +def test_get_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + client.get_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_get_data_labeling_job_async( transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest @@ -2190,6 +2320,24 @@ def test_list_data_labeling_jobs_from_dict(): test_list_data_labeling_jobs(request_type=dict) +def test_list_data_labeling_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + client.list_data_labeling_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListDataLabelingJobsRequest() + + @pytest.mark.asyncio async def test_list_data_labeling_jobs_async( transport: str = "grpc_asyncio", @@ -2563,6 +2711,24 @@ def test_delete_data_labeling_job_from_dict(): test_delete_data_labeling_job(request_type=dict) +def test_delete_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + client.delete_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_delete_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -2762,6 +2928,24 @@ def test_cancel_data_labeling_job_from_dict(): test_cancel_data_labeling_job(request_type=dict) +def test_cancel_data_labeling_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + client.cancel_data_labeling_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelDataLabelingJobRequest() + + @pytest.mark.asyncio async def test_cancel_data_labeling_job_async( transport: str = "grpc_asyncio", @@ -2976,6 +3160,24 @@ def test_create_hyperparameter_tuning_job_from_dict(): test_create_hyperparameter_tuning_job(request_type=dict) +def test_create_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + client.create_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_create_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -3244,6 +3446,24 @@ def test_get_hyperparameter_tuning_job_from_dict(): test_get_hyperparameter_tuning_job(request_type=dict) +def test_get_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + client.get_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_get_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -3468,6 +3688,24 @@ def test_list_hyperparameter_tuning_jobs_from_dict(): test_list_hyperparameter_tuning_jobs(request_type=dict) +def test_list_hyperparameter_tuning_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + client.list_hyperparameter_tuning_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + + @pytest.mark.asyncio async def test_list_hyperparameter_tuning_jobs_async( transport: str = "grpc_asyncio", @@ -3858,6 +4096,24 @@ def test_delete_hyperparameter_tuning_job_from_dict(): test_delete_hyperparameter_tuning_job(request_type=dict) +def test_delete_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + client.delete_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_delete_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -4058,6 +4314,24 @@ def test_cancel_hyperparameter_tuning_job_from_dict(): test_cancel_hyperparameter_tuning_job(request_type=dict) +def test_cancel_hyperparameter_tuning_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + client.cancel_hyperparameter_tuning_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + @pytest.mark.asyncio async def test_cancel_hyperparameter_tuning_job_async( transport: str = "grpc_asyncio", @@ -4268,6 +4542,24 @@ def test_create_batch_prediction_job_from_dict(): test_create_batch_prediction_job(request_type=dict) +def test_create_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + client.create_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_create_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -4530,6 +4822,24 @@ def test_get_batch_prediction_job_from_dict(): test_get_batch_prediction_job(request_type=dict) +def test_get_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + client.get_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_get_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -4750,6 +5060,24 @@ def test_list_batch_prediction_jobs_from_dict(): test_list_batch_prediction_jobs(request_type=dict) +def test_list_batch_prediction_jobs_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + client.list_batch_prediction_jobs() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListBatchPredictionJobsRequest() + + @pytest.mark.asyncio async def test_list_batch_prediction_jobs_async( transport: str = "grpc_asyncio", @@ -5127,6 +5455,24 @@ def test_delete_batch_prediction_job_from_dict(): test_delete_batch_prediction_job(request_type=dict) +def test_delete_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + client.delete_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_delete_batch_prediction_job_async( transport: str = "grpc_asyncio", @@ -5326,6 +5672,24 @@ def test_cancel_batch_prediction_job_from_dict(): test_cancel_batch_prediction_job(request_type=dict) +def test_cancel_batch_prediction_job_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + client.cancel_batch_prediction_job() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + @pytest.mark.asyncio async def test_cancel_batch_prediction_job_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 8594354c88..37ae2b65e8 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -94,21 +94,25 @@ def test__get_default_mtls_endpoint(): ) -def test_migration_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], +) +def test_migration_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = MigrationServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,], ) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -118,9 +122,11 @@ def test_migration_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -496,6 +502,24 @@ def test_search_migratable_resources_from_dict(): test_search_migratable_resources(request_type=dict) +def test_search_migratable_resources_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + client.search_migratable_resources() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + @pytest.mark.asyncio async def test_search_migratable_resources_async( transport: str = "grpc_asyncio", @@ -879,6 +903,24 @@ def test_batch_migrate_resources_from_dict(): test_batch_migrate_resources(request_type=dict) +def test_batch_migrate_resources_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + client.batch_migrate_resources() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.BatchMigrateResourcesRequest() + + @pytest.mark.asyncio async def test_batch_migrate_resources_async( transport: str = "grpc_asyncio", @@ -1509,21 +1551,19 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - location = "mussel" - dataset = "winkle" + dataset = "mussel" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, + expected = "projects/{project}/datasets/{dataset}".format( + project=project, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -1533,19 +1573,21 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "squid" - dataset = "clam" + project = "scallop" + location = "abalone" + dataset = "squid" - expected = "projects/{project}/datasets/{dataset}".format( - project=project, dataset=dataset, + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", + "project": "clam", + "location": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index 05bb815f3f..51cbd4583f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -101,20 +101,26 @@ def test__get_default_mtls_endpoint(): assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -def test_model_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [ModelServiceClient, ModelServiceAsyncClient,], +) +def test_model_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = ModelServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" -@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) +@pytest.mark.parametrize( + "client_class", [ModelServiceClient, ModelServiceAsyncClient,], +) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() 
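The reshuffled dataset_path tests above cover the two resource-name shapes the migration service deals with: `projects/{project}/datasets/{dataset}` and `projects/{project}/locations/{location}/datasets/{dataset}`. As a rough sketch of how such a path/parse helper pair behaves for the location-qualified form (a local illustration only, not MigrationServiceClient's own implementation):

import re


def dataset_path(project: str, location: str, dataset: str) -> str:
    # Build the location-qualified resource name used by the second test above.
    return f"projects/{project}/locations/{location}/datasets/{dataset}"


def parse_dataset_path(path: str) -> dict:
    # Inverse of dataset_path(); returns an empty dict when the path does not match.
    match = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)"
        r"/datasets/(?P<dataset>.+?)$",
        path,
    )
    return match.groupdict() if match else {}


assert parse_dataset_path(dataset_path("scallop", "abalone", "squid")) == {
    "project": "scallop",
    "location": "abalone",
    "dataset": "squid",
}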
with mock.patch.object( @@ -123,9 +129,11 @@ def test_model_service_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -477,6 +485,22 @@ def test_upload_model_from_dict(): test_upload_model(request_type=dict) +def test_upload_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + client.upload_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + @pytest.mark.asyncio async def test_upload_model_async( transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest @@ -713,6 +737,22 @@ def test_get_model_from_dict(): test_get_model(request_type=dict) +def test_get_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + client.get_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelRequest() + + @pytest.mark.asyncio async def test_get_model_async( transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest @@ -943,6 +983,22 @@ def test_list_models_from_dict(): test_list_models(request_type=dict) +def test_list_models_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + client.list_models() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelsRequest() + + @pytest.mark.asyncio async def test_list_models_async( transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest @@ -1285,6 +1341,22 @@ def test_update_model_from_dict(): test_update_model(request_type=dict) +def test_update_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.update_model), "__call__") as call: + client.update_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UpdateModelRequest() + + @pytest.mark.asyncio async def test_update_model_async( transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest @@ -1524,6 +1596,22 @@ def test_delete_model_from_dict(): test_delete_model(request_type=dict) +def test_delete_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + client.delete_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.DeleteModelRequest() + + @pytest.mark.asyncio async def test_delete_model_async( transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest @@ -1710,6 +1798,22 @@ def test_export_model_from_dict(): test_export_model(request_type=dict) +def test_export_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + client.export_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ExportModelRequest() + + @pytest.mark.asyncio async def test_export_model_async( transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest @@ -1935,6 +2039,24 @@ def test_get_model_evaluation_from_dict(): test_get_model_evaluation(request_type=dict) +def test_get_model_evaluation_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + client.get_model_evaluation() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationRequest() + + @pytest.mark.asyncio async def test_get_model_evaluation_async( transport: str = "grpc_asyncio", @@ -2149,6 +2271,24 @@ def test_list_model_evaluations_from_dict(): test_list_model_evaluations(request_type=dict) +def test_list_model_evaluations_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.list_model_evaluations), "__call__" + ) as call: + client.list_model_evaluations() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelEvaluationsRequest() + + @pytest.mark.asyncio async def test_list_model_evaluations_async( transport: str = "grpc_asyncio", @@ -2529,6 +2669,24 @@ def test_get_model_evaluation_slice_from_dict(): test_get_model_evaluation_slice(request_type=dict) +def test_get_model_evaluation_slice_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + client.get_model_evaluation_slice() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationSliceRequest() + + @pytest.mark.asyncio async def test_get_model_evaluation_slice_async( transport: str = "grpc_asyncio", @@ -2739,6 +2897,24 @@ def test_list_model_evaluation_slices_from_dict(): test_list_model_evaluation_slices(request_type=dict) +def test_list_model_evaluation_slices_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + client.list_model_evaluation_slices() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelEvaluationSlicesRequest() + + @pytest.mark.asyncio async def test_list_model_evaluation_slices_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py index 8135921566..d1d65aecbd 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -110,21 +110,25 @@ def test__get_default_mtls_endpoint(): ) -def test_pipeline_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], +) +def test_pipeline_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = PipelineServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,], ) def test_pipeline_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -134,9 +138,11 @@ def test_pipeline_service_client_from_service_account_file(client_class): 
factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -520,6 +526,24 @@ def test_create_training_pipeline_from_dict(): test_create_training_pipeline(request_type=dict) +def test_create_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + client.create_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_create_training_pipeline_async( transport: str = "grpc_asyncio", @@ -764,6 +788,24 @@ def test_get_training_pipeline_from_dict(): test_get_training_pipeline(request_type=dict) +def test_get_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + client.get_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_get_training_pipeline_async( transport: str = "grpc_asyncio", @@ -981,6 +1023,24 @@ def test_list_training_pipelines_from_dict(): test_list_training_pipelines(request_type=dict) +def test_list_training_pipelines_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + client.list_training_pipelines() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + @pytest.mark.asyncio async def test_list_training_pipelines_async( transport: str = "grpc_asyncio", @@ -1354,6 +1414,24 @@ def test_delete_training_pipeline_from_dict(): test_delete_training_pipeline(request_type=dict) +def test_delete_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
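A recurring change across these service test files is turning the single-class from_service_account_info tests into pytest-parametrized tests that run against both the sync and async client classes, plus an `isinstance(client, client_class)` assertion so each variant is actually exercised. A minimal sketch of that pattern, with SyncWidget and AsyncWidget as hypothetical stand-ins rather than the aiplatform clients:

import pytest


class SyncWidget:
    @classmethod
    def from_info(cls, info):
        # Factory constructor analogous to from_service_account_info().
        obj = cls()
        obj.info = info
        return obj


class AsyncWidget(SyncWidget):
    pass


@pytest.mark.parametrize("client_class", [SyncWidget, AsyncWidget])
def test_from_info(client_class):
    client = client_class.from_info({"valid": True})
    # The assertion these hunks add: the factory must hand back an instance of
    # the class it was called on, not a hard-coded sync client.
    assert isinstance(client, client_class)
    assert client.info == {"valid": True}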
+ with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + client.delete_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_delete_training_pipeline_async( transport: str = "grpc_asyncio", @@ -1553,6 +1631,24 @@ def test_cancel_training_pipeline_from_dict(): test_cancel_training_pipeline(request_type=dict) +def test_cancel_training_pipeline_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + client.cancel_training_pipeline() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + + @pytest.mark.asyncio async def test_cancel_training_pipeline_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py index 311f03d5b7..ba5333c0fa 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -104,7 +104,7 @@ def test_prediction_service_client_from_service_account_info(): @pytest.mark.parametrize( - "client_class", [PredictionServiceClient, PredictionServiceAsyncClient,] + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient,], ) def test_prediction_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py index c91839ff1a..879a0a69d5 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -97,21 +97,25 @@ def test__get_default_mtls_endpoint(): ) -def test_specialist_pool_service_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], +) +def test_specialist_pool_service_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = SpecialistPoolServiceClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,], ) def test_specialist_pool_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -121,9 +125,11 @@ def test_specialist_pool_service_client_from_service_account_file(client_class): factory.return_value = creds client = 
client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "aiplatform.googleapis.com:443" @@ -506,6 +512,24 @@ def test_create_specialist_pool_from_dict(): test_create_specialist_pool(request_type=dict) +def test_create_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + client.create_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_create_specialist_pool_async( transport: str = "grpc_asyncio", @@ -753,6 +777,24 @@ def test_get_specialist_pool_from_dict(): test_get_specialist_pool(request_type=dict) +def test_get_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + client.get_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_get_specialist_pool_async( transport: str = "grpc_asyncio", @@ -986,6 +1028,24 @@ def test_list_specialist_pools_from_dict(): test_list_specialist_pools(request_type=dict) +def test_list_specialist_pools_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + client.list_specialist_pools() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + @pytest.mark.asyncio async def test_list_specialist_pools_async( transport: str = "grpc_asyncio", @@ -1376,6 +1436,24 @@ def test_delete_specialist_pool_from_dict(): test_delete_specialist_pool(request_type=dict) +def test_delete_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + client.delete_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_delete_specialist_pool_async( transport: str = "grpc_asyncio", @@ -1588,6 +1666,24 @@ def test_update_specialist_pool_from_dict(): test_update_specialist_pool(request_type=dict) +def test_update_specialist_pool_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + client.update_specialist_pool() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + + @pytest.mark.asyncio async def test_update_specialist_pool_async( transport: str = "grpc_asyncio", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py new file mode 100644 index 0000000000..5f1aec70ab --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_vizier_service.py @@ -0,0 +1,4228 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.vizier_service import ( + VizierServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.vizier_service import VizierServiceClient +from google.cloud.aiplatform_v1beta1.services.vizier_service import pagers +from google.cloud.aiplatform_v1beta1.services.vizier_service import transports +from google.cloud.aiplatform_v1beta1.types import study +from google.cloud.aiplatform_v1beta1.types import study as gca_study +from google.cloud.aiplatform_v1beta1.types import vizier_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert VizierServiceClient._get_default_mtls_endpoint(None) is None + assert ( + VizierServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + VizierServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [VizierServiceClient, VizierServiceAsyncClient,], +) +def test_vizier_service_client_from_service_account_info(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [VizierServiceClient, VizierServiceAsyncClient,], +) +def test_vizier_service_client_from_service_account_file(client_class): + creds = 
credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_vizier_service_client_get_transport_class(): + transport = VizierServiceClient.get_transport_class() + available_transports = [ + transports.VizierServiceGrpcTransport, + ] + assert transport in available_transports + + transport = VizierServiceClient.get_transport_class("grpc") + assert transport == transports.VizierServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + VizierServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceClient), +) +@mock.patch.object( + VizierServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceAsyncClient), +) +def test_vizier_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(VizierServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(VizierServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "true"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc", "false"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + VizierServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceClient), +) +@mock.patch.object( + VizierServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(VizierServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_vizier_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_vizier_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. 
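test_vizier_service_client_mtls_env_auto above pins down how the client chooses its endpoint and client certificate from GOOGLE_API_USE_CLIENT_CERTIFICATE and GOOGLE_API_USE_MTLS_ENDPOINT. A condensed sketch of that decision logic, assuming it mirrors what the test expects and ignoring the ADC default-cert branch the test also covers (resolve_endpoint is a local illustration, not the library's code):

import os
from typing import Callable, Optional, Tuple

DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
DEFAULT_MTLS_ENDPOINT = "aiplatform.mtls.googleapis.com"


def resolve_endpoint(
    client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
) -> Tuple[str, Optional[Callable[[], Tuple[bytes, bytes]]]]:
    use_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
    use_mtls = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
    if use_cert not in ("true", "false"):
        # The real client also rejects this case with a ValueError.
        raise ValueError("GOOGLE_API_USE_CLIENT_CERTIFICATE must be 'true' or 'false'")
    if use_mtls not in ("never", "auto", "always"):
        # The real client raises MutualTLSChannelError; ValueError keeps this stdlib-only.
        raise ValueError("GOOGLE_API_USE_MTLS_ENDPOINT must be 'never', 'auto' or 'always'")
    cert = client_cert_source if use_cert == "true" else None
    if use_mtls == "always" or (use_mtls == "auto" and cert is not None):
        return DEFAULT_MTLS_ENDPOINT, cert
    return DEFAULT_ENDPOINT, cert

This matches the expectations in the test: with GOOGLE_API_USE_CLIENT_CERTIFICATE=false the cert source is dropped and the default endpoint is used; with it set to true and a cert available, the mTLS endpoint and the cert source are selected.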
+ options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (VizierServiceClient, transports.VizierServiceGrpcTransport, "grpc"), + ( + VizierServiceAsyncClient, + transports.VizierServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_vizier_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_vizier_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = VizierServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_study( + transport: str = "grpc", request_type=vizier_service.CreateStudyRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_study.Study( + name="name_value", + display_name="display_name_value", + state=gca_study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + + response = client.create_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CreateStudyRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_study.Study) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == gca_study.Study.State.ACTIVE + + assert response.inactive_reason == "inactive_reason_value" + + +def test_create_study_from_dict(): + test_create_study(request_type=dict) + + +def test_create_study_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. 
request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_study), "__call__") as call: + client.create_study() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CreateStudyRequest() + + +@pytest.mark.asyncio +async def test_create_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CreateStudyRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_study.Study( + name="name_value", + display_name="display_name_value", + state=gca_study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + ) + + response = await client.create_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CreateStudyRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_study.Study) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == gca_study.Study.State.ACTIVE + + assert response.inactive_reason == "inactive_reason_value" + + +@pytest.mark.asyncio +async def test_create_study_async_from_dict(): + await test_create_study_async(request_type=dict) + + +def test_create_study_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CreateStudyRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_study), "__call__") as call: + call.return_value = gca_study.Study() + + client.create_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_study_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CreateStudyRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
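The async variants in this file wrap each mocked return value in grpc_helpers_async.FakeUnaryUnaryCall because the awaited gRPC call has to resolve to a response and a plain MagicMock result is not awaitable. A stripped-down sketch of that technique, using local stand-ins (FakeUnaryUnaryCall, ToyAsyncTransport and ToyAsyncClient are illustrative, not the real gapic or aiplatform classes) and asyncio.run instead of pytest-asyncio to stay dependency-light:

import asyncio
from unittest import mock


class FakeUnaryUnaryCall:
    # A plain awaitable that resolves to the canned response handed to it.
    def __init__(self, response):
        self._response = response

    def __await__(self):
        if False:
            yield  # turns __await__ into a generator-based awaitable
        return self._response


class ToyAsyncTransport:
    def get_study(self, request):
        raise NotImplementedError("replaced by mock.patch.object in the test")


class ToyAsyncClient:
    def __init__(self):
        self.transport = ToyAsyncTransport()

    async def get_study(self, request):
        return await self.transport.get_study(request)


def test_get_study_async_pattern():
    client = ToyAsyncClient()
    with mock.patch.object(type(client.transport), "get_study") as call:
        call.return_value = FakeUnaryUnaryCall({"name": "name_value"})
        response = asyncio.run(client.get_study({"name": "name_value"}))
        assert response == {"name": "name_value"}
        _, args, _ = call.mock_calls[0]
        assert args[0] == {"name": "name_value"}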
+ with mock.patch.object(type(client.transport.create_study), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_study.Study()) + + await client.create_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_study_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_study.Study() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_study( + parent="parent_value", study=gca_study.Study(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].study == gca_study.Study(name="name_value") + + +def test_create_study_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_study( + vizier_service.CreateStudyRequest(), + parent="parent_value", + study=gca_study.Study(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_study_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_study.Study() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_study.Study()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_study( + parent="parent_value", study=gca_study.Study(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].study == gca_study.Study(name="name_value") + + +@pytest.mark.asyncio +async def test_create_study_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_study( + vizier_service.CreateStudyRequest(), + parent="parent_value", + study=gca_study.Study(name="name_value"), + ) + + +def test_get_study( + transport: str = "grpc", request_type=vizier_service.GetStudyRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
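The *_flattened and *_flattened_error tests just above enforce the convention that callers pass either a request object or individual (flattened) fields, never both. A small sketch of that rule; create_study and CreateStudyRequest here are stand-ins, not VizierServiceClient's generated surface, and the error message is illustrative:

from typing import Optional


class CreateStudyRequest:
    def __init__(self, parent: str = "", study=None):
        self.parent = parent
        self.study = study


def create_study(request: Optional[CreateStudyRequest] = None, *, parent=None, study=None):
    if request is not None and (parent is not None or study is not None):
        # Mixing a request object with flattened fields is rejected, as the
        # *_flattened_error tests expect.
        raise ValueError("Pass either `request` or flattened fields, not both.")
    return request or CreateStudyRequest(parent=parent or "", study=study)
    # the real client would now hand the request to the transport


assert create_study(parent="parent_value").parent == "parent_value"
try:
    create_study(CreateStudyRequest(), parent="parent_value")
except ValueError:
    pass  # both together is an error, as test_create_study_flattened_error expects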
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Study( + name="name_value", + display_name="display_name_value", + state=study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + + response = client.get_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.GetStudyRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Study) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == study.Study.State.ACTIVE + + assert response.inactive_reason == "inactive_reason_value" + + +def test_get_study_from_dict(): + test_get_study(request_type=dict) + + +def test_get_study_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_study), "__call__") as call: + client.get_study() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.GetStudyRequest() + + +@pytest.mark.asyncio +async def test_get_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.GetStudyRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Study( + name="name_value", + display_name="display_name_value", + state=study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + ) + + response = await client.get_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.GetStudyRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, study.Study) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == study.Study.State.ACTIVE + + assert response.inactive_reason == "inactive_reason_value" + + +@pytest.mark.asyncio +async def test_get_study_async_from_dict(): + await test_get_study_async(request_type=dict) + + +def test_get_study_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.GetStudyRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_study), "__call__") as call: + call.return_value = study.Study() + + client.get_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_study_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.GetStudyRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_study), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) + + await client.get_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_study_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Study() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_study(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_study_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_study( + vizier_service.GetStudyRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_study_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Study() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_study(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_study_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.get_study( + vizier_service.GetStudyRequest(), name="name_value", + ) + + +def test_list_studies( + transport: str = "grpc", request_type=vizier_service.ListStudiesRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListStudiesResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListStudiesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListStudiesPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_studies_from_dict(): + test_list_studies(request_type=dict) + + +def test_list_studies_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + client.list_studies() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListStudiesRequest() + + +@pytest.mark.asyncio +async def test_list_studies_async( + transport: str = "grpc_asyncio", request_type=vizier_service.ListStudiesRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListStudiesResponse(next_page_token="next_page_token_value",) + ) + + response = await client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListStudiesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListStudiesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_studies_async_from_dict(): + await test_list_studies_async(request_type=dict) + + +def test_list_studies_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = vizier_service.ListStudiesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value = vizier_service.ListStudiesResponse() + + client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_studies_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.ListStudiesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListStudiesResponse() + ) + + await client.list_studies(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_studies_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListStudiesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_studies(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_studies_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_studies( + vizier_service.ListStudiesRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_studies_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListStudiesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListStudiesResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_studies(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_studies_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_studies( + vizier_service.ListStudiesRequest(), parent="parent_value", + ) + + +def test_list_studies_pager(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", + ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), + vizier_service.ListStudiesResponse( + studies=[study.Study(),], next_page_token="ghi", + ), + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_studies(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, study.Study) for i in results) + + +def test_list_studies_pages(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_studies), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", + ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), + vizier_service.ListStudiesResponse( + studies=[study.Study(),], next_page_token="ghi", + ), + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(),], + ), + RuntimeError, + ) + pages = list(client.list_studies(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_studies_async_pager(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_studies), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
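+        # The trailing RuntimeError acts as a sentinel: the pager is expected
+        # to stop once it sees the empty next_page_token, so any extra page
+        # request would raise instead of silently looping.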
+ call.side_effect = ( + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", + ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), + vizier_service.ListStudiesResponse( + studies=[study.Study(),], next_page_token="ghi", + ), + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(),], + ), + RuntimeError, + ) + async_pager = await client.list_studies(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, study.Study) for i in responses) + + +@pytest.mark.asyncio +async def test_list_studies_async_pages(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_studies), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(), study.Study(),], + next_page_token="abc", + ), + vizier_service.ListStudiesResponse(studies=[], next_page_token="def",), + vizier_service.ListStudiesResponse( + studies=[study.Study(),], next_page_token="ghi", + ), + vizier_service.ListStudiesResponse( + studies=[study.Study(), study.Study(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_studies(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_study( + transport: str = "grpc", request_type=vizier_service.DeleteStudyRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.DeleteStudyRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_study_from_dict(): + test_delete_study(request_type=dict) + + +def test_delete_study_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + client.delete_study() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.DeleteStudyRequest() + + +@pytest.mark.asyncio +async def test_delete_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.DeleteStudyRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.DeleteStudyRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_study_async_from_dict(): + await test_delete_study_async(request_type=dict) + + +def test_delete_study_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.DeleteStudyRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + call.return_value = None + + client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_study_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.DeleteStudyRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.delete_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_study_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_study(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_study_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_study( + vizier_service.DeleteStudyRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_study_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_study(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_study_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_study( + vizier_service.DeleteStudyRequest(), name="name_value", + ) + + +def test_lookup_study( + transport: str = "grpc", request_type=vizier_service.LookupStudyRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Study( + name="name_value", + display_name="display_name_value", + state=study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + + response = client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.LookupStudyRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Study) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == study.Study.State.ACTIVE + + assert response.inactive_reason == "inactive_reason_value" + + +def test_lookup_study_from_dict(): + test_lookup_study(request_type=dict) + + +def test_lookup_study_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + client.lookup_study() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.LookupStudyRequest() + + +@pytest.mark.asyncio +async def test_lookup_study_async( + transport: str = "grpc_asyncio", request_type=vizier_service.LookupStudyRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Study( + name="name_value", + display_name="display_name_value", + state=study.Study.State.ACTIVE, + inactive_reason="inactive_reason_value", + ) + ) + + response = await client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.LookupStudyRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, study.Study) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == study.Study.State.ACTIVE + + assert response.inactive_reason == "inactive_reason_value" + + +@pytest.mark.asyncio +async def test_lookup_study_async_from_dict(): + await test_lookup_study_async(request_type=dict) + + +def test_lookup_study_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.LookupStudyRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + call.return_value = study.Study() + + client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_lookup_study_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.LookupStudyRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) + + await client.lookup_study(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_lookup_study_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Study() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.lookup_study(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_lookup_study_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.lookup_study( + vizier_service.LookupStudyRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_lookup_study_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.lookup_study), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Study() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Study()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.lookup_study(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_lookup_study_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.lookup_study( + vizier_service.LookupStudyRequest(), parent="parent_value", + ) + + +def test_suggest_trials( + transport: str = "grpc", request_type=vizier_service.SuggestTrialsRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.SuggestTrialsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_suggest_trials_from_dict(): + test_suggest_trials(request_type=dict) + + +def test_suggest_trials_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + client.suggest_trials() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.SuggestTrialsRequest() + + +@pytest.mark.asyncio +async def test_suggest_trials_async( + transport: str = "grpc_asyncio", request_type=vizier_service.SuggestTrialsRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.SuggestTrialsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_suggest_trials_async_from_dict(): + await test_suggest_trials_async(request_type=dict) + + +def test_suggest_trials_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.SuggestTrialsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_suggest_trials_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.SuggestTrialsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.suggest_trials), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.suggest_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_trial( + transport: str = "grpc", request_type=vizier_service.CreateTrialRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + + response = client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CreateTrialRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +def test_create_trial_from_dict(): + test_create_trial(request_type=dict) + + +def test_create_trial_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + client.create_trial() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CreateTrialRequest() + + +@pytest.mark.asyncio +async def test_create_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CreateTrialRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + ) + + response = await client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CreateTrialRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +@pytest.mark.asyncio +async def test_create_trial_async_from_dict(): + await test_create_trial_async(request_type=dict) + + +def test_create_trial_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CreateTrialRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + call.return_value = study.Trial() + + client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_trial_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CreateTrialRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + + await client.create_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_trial_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_trial( + parent="parent_value", trial=study.Trial(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].trial == study.Trial(name="name_value") + + +def test_create_trial_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_trial( + vizier_service.CreateTrialRequest(), + parent="parent_value", + trial=study.Trial(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_trial_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_trial( + parent="parent_value", trial=study.Trial(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].trial == study.Trial(name="name_value") + + +@pytest.mark.asyncio +async def test_create_trial_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_trial( + vizier_service.CreateTrialRequest(), + parent="parent_value", + trial=study.Trial(name="name_value"), + ) + + +def test_get_trial( + transport: str = "grpc", request_type=vizier_service.GetTrialRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + + response = client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.GetTrialRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +def test_get_trial_from_dict(): + test_get_trial(request_type=dict) + + +def test_get_trial_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + client.get_trial() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.GetTrialRequest() + + +@pytest.mark.asyncio +async def test_get_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.GetTrialRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + ) + + response = await client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.GetTrialRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +@pytest.mark.asyncio +async def test_get_trial_async_from_dict(): + await test_get_trial_async(request_type=dict) + + +def test_get_trial_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.GetTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + call.return_value = study.Trial() + + client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_trial_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.GetTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + + await client.get_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_trial_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_trial(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_trial_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_trial( + vizier_service.GetTrialRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_trial_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_trial(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_trial_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_trial( + vizier_service.GetTrialRequest(), name="name_value", + ) + + +def test_list_trials( + transport: str = "grpc", request_type=vizier_service.ListTrialsRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListTrialsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListTrialsRequest() + + # Establish that the response is the type that we expect. 
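+    # list_trials returns a pager that wraps the first response; it surfaces
+    # next_page_token from that page and fetches any further pages lazily as
+    # the pager is iterated.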
+ + assert isinstance(response, pagers.ListTrialsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_trials_from_dict(): + test_list_trials(request_type=dict) + + +def test_list_trials_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + client.list_trials() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListTrialsRequest() + + +@pytest.mark.asyncio +async def test_list_trials_async( + transport: str = "grpc_asyncio", request_type=vizier_service.ListTrialsRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListTrialsResponse(next_page_token="next_page_token_value",) + ) + + response = await client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListTrialsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrialsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_trials_async_from_dict(): + await test_list_trials_async(request_type=dict) + + +def test_list_trials_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.ListTrialsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value = vizier_service.ListTrialsResponse() + + client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_trials_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.ListTrialsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListTrialsResponse() + ) + + await client.list_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_trials_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListTrialsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_trials(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_trials_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_trials( + vizier_service.ListTrialsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_trials_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListTrialsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListTrialsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_trials(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_trials_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_trials( + vizier_service.ListTrialsRequest(), parent="parent_value", + ) + + +def test_list_trials_pager(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + vizier_service.ListTrialsResponse( + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", + ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), + vizier_service.ListTrialsResponse( + trials=[study.Trial(),], next_page_token="ghi", + ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_trials(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, study.Trial) for i in results) + + +def test_list_trials_pages(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_trials), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + vizier_service.ListTrialsResponse( + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", + ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), + vizier_service.ListTrialsResponse( + trials=[study.Trial(),], next_page_token="ghi", + ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), + RuntimeError, + ) + pages = list(client.list_trials(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_trials_async_pager(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_trials), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + vizier_service.ListTrialsResponse( + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", + ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), + vizier_service.ListTrialsResponse( + trials=[study.Trial(),], next_page_token="ghi", + ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), + RuntimeError, + ) + async_pager = await client.list_trials(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, study.Trial) for i in responses) + + +@pytest.mark.asyncio +async def test_list_trials_async_pages(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_trials), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + vizier_service.ListTrialsResponse( + trials=[study.Trial(), study.Trial(), study.Trial(),], + next_page_token="abc", + ), + vizier_service.ListTrialsResponse(trials=[], next_page_token="def",), + vizier_service.ListTrialsResponse( + trials=[study.Trial(),], next_page_token="ghi", + ), + vizier_service.ListTrialsResponse(trials=[study.Trial(), study.Trial(),],), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_trials(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_add_trial_measurement( + transport: str = "grpc", request_type=vizier_service.AddTrialMeasurementRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_trial_measurement), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + + response = client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.AddTrialMeasurementRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +def test_add_trial_measurement_from_dict(): + test_add_trial_measurement(request_type=dict) + + +def test_add_trial_measurement_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_trial_measurement), "__call__" + ) as call: + client.add_trial_measurement() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.AddTrialMeasurementRequest() + + +@pytest.mark.asyncio +async def test_add_trial_measurement_async( + transport: str = "grpc_asyncio", + request_type=vizier_service.AddTrialMeasurementRequest, +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_trial_measurement), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + ) + + response = await client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.AddTrialMeasurementRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +@pytest.mark.asyncio +async def test_add_trial_measurement_async_from_dict(): + await test_add_trial_measurement_async(request_type=dict) + + +def test_add_trial_measurement_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.AddTrialMeasurementRequest() + request.trial_name = "trial_name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_trial_measurement), "__call__" + ) as call: + call.return_value = study.Trial() + + client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_add_trial_measurement_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.AddTrialMeasurementRequest() + request.trial_name = "trial_name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.add_trial_measurement), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + + await client.add_trial_measurement(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + + +def test_complete_trial( + transport: str = "grpc", request_type=vizier_service.CompleteTrialRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + + response = client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CompleteTrialRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +def test_complete_trial_from_dict(): + test_complete_trial(request_type=dict) + + +def test_complete_trial_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + client.complete_trial() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CompleteTrialRequest() + + +@pytest.mark.asyncio +async def test_complete_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.CompleteTrialRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + ) + + response = await client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CompleteTrialRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +@pytest.mark.asyncio +async def test_complete_trial_async_from_dict(): + await test_complete_trial_async(request_type=dict) + + +def test_complete_trial_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CompleteTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + call.return_value = study.Trial() + + client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_complete_trial_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CompleteTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.complete_trial), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + + await client.complete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_trial( + transport: str = "grpc", request_type=vizier_service.DeleteTrialRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.DeleteTrialRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_trial_from_dict(): + test_delete_trial(request_type=dict) + + +def test_delete_trial_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + client.delete_trial() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.DeleteTrialRequest() + + +@pytest.mark.asyncio +async def test_delete_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.DeleteTrialRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.DeleteTrialRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_trial_async_from_dict(): + await test_delete_trial_async(request_type=dict) + + +def test_delete_trial_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.DeleteTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + call.return_value = None + + client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_trial_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.DeleteTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.delete_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_trial_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_trial(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_trial_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.delete_trial( + vizier_service.DeleteTrialRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_trial_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_trial(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_trial_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_trial( + vizier_service.DeleteTrialRequest(), name="name_value", + ) + + +def test_check_trial_early_stopping_state( + transport: str = "grpc", + request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_check_trial_early_stopping_state_from_dict(): + test_check_trial_early_stopping_state(request_type=dict) + + +def test_check_trial_early_stopping_state_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + client.check_trial_early_stopping_state() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() + + +@pytest.mark.asyncio +async def test_check_trial_early_stopping_state_async( + transport: str = "grpc_asyncio", + request_type=vizier_service.CheckTrialEarlyStoppingStateRequest, +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.CheckTrialEarlyStoppingStateRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_check_trial_early_stopping_state_async_from_dict(): + await test_check_trial_early_stopping_state_async(request_type=dict) + + +def test_check_trial_early_stopping_state_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CheckTrialEarlyStoppingStateRequest() + request.trial_name = "trial_name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_check_trial_early_stopping_state_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.CheckTrialEarlyStoppingStateRequest() + request.trial_name = "trial_name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.check_trial_early_stopping_state), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.check_trial_early_stopping_state(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "trial_name=trial_name/value",) in kw["metadata"] + + +def test_stop_trial( + transport: str = "grpc", request_type=vizier_service.StopTrialRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + + response = client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.StopTrialRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +def test_stop_trial_from_dict(): + test_stop_trial(request_type=dict) + + +def test_stop_trial_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + client.stop_trial() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.StopTrialRequest() + + +@pytest.mark.asyncio +async def test_stop_trial_async( + transport: str = "grpc_asyncio", request_type=vizier_service.StopTrialRequest +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + study.Trial( + name="name_value", + id="id_value", + state=study.Trial.State.REQUESTED, + custom_job="custom_job_value", + ) + ) + + response = await client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.StopTrialRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, study.Trial) + + assert response.name == "name_value" + + assert response.id == "id_value" + + assert response.state == study.Trial.State.REQUESTED + + assert response.custom_job == "custom_job_value" + + +@pytest.mark.asyncio +async def test_stop_trial_async_from_dict(): + await test_stop_trial_async(request_type=dict) + + +def test_stop_trial_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.StopTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + call.return_value = study.Trial() + + client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_stop_trial_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.StopTrialRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_trial), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(study.Trial()) + + await client.stop_trial(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_list_optimal_trials( + transport: str = "grpc", request_type=vizier_service.ListOptimalTrialsRequest +): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListOptimalTrialsResponse() + + response = client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListOptimalTrialsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, vizier_service.ListOptimalTrialsResponse) + + +def test_list_optimal_trials_from_dict(): + test_list_optimal_trials(request_type=dict) + + +def test_list_optimal_trials_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + client.list_optimal_trials() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListOptimalTrialsRequest() + + +@pytest.mark.asyncio +async def test_list_optimal_trials_async( + transport: str = "grpc_asyncio", + request_type=vizier_service.ListOptimalTrialsRequest, +): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListOptimalTrialsResponse() + ) + + response = await client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == vizier_service.ListOptimalTrialsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, vizier_service.ListOptimalTrialsResponse) + + +@pytest.mark.asyncio +async def test_list_optimal_trials_async_from_dict(): + await test_list_optimal_trials_async(request_type=dict) + + +def test_list_optimal_trials_field_headers(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.ListOptimalTrialsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + call.return_value = vizier_service.ListOptimalTrialsResponse() + + client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_optimal_trials_field_headers_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = vizier_service.ListOptimalTrialsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListOptimalTrialsResponse() + ) + + await client.list_optimal_trials(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_optimal_trials_flattened(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListOptimalTrialsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_optimal_trials(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_optimal_trials_flattened_error(): + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_optimal_trials( + vizier_service.ListOptimalTrialsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_optimal_trials_flattened_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_optimal_trials), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = vizier_service.ListOptimalTrialsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + vizier_service.ListOptimalTrialsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_optimal_trials(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_optimal_trials_flattened_error_async(): + client = VizierServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_optimal_trials( + vizier_service.ListOptimalTrialsRequest(), parent="parent_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.VizierServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.VizierServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = VizierServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. 
+ transport = transports.VizierServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = VizierServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.VizierServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = VizierServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.VizierServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.VizierServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = VizierServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.VizierServiceGrpcTransport,) + + +def test_vizier_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.VizierServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_vizier_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.VizierServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
+ methods = ( + "create_study", + "get_study", + "list_studies", + "delete_study", + "lookup_study", + "suggest_trials", + "create_trial", + "get_trial", + "list_trials", + "add_trial_measurement", + "complete_trial", + "delete_trial", + "check_trial_early_stopping_state", + "stop_trial", + "list_optimal_trials", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_vizier_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.VizierServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_vizier_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.vizier_service.transports.VizierServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.VizierServiceTransport() + adc.assert_called_once() + + +def test_vizier_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + VizierServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_vizier_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.VizierServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_vizier_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. 
+ with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_vizier_service_host_no_port(): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_vizier_service_host_with_port(): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_vizier_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.VizierServiceGrpcTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_vizier_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.VizierServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
+@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_vizier_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.VizierServiceGrpcTransport, + transports.VizierServiceGrpcAsyncIOTransport, + ], +) +def test_vizier_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_vizier_service_grpc_lro_client(): + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. 
+ assert transport.operations_client is transport.operations_client + + +def test_vizier_service_grpc_lro_async_client(): + client = VizierServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_custom_job_path(): + project = "squid" + location = "clam" + custom_job = "whelk" + + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + actual = VizierServiceClient.custom_job_path(project, location, custom_job) + assert expected == actual + + +def test_parse_custom_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "custom_job": "nudibranch", + } + path = VizierServiceClient.custom_job_path(**expected) + + # Check that the path construction is reversible. + actual = VizierServiceClient.parse_custom_job_path(path) + assert expected == actual + + +def test_study_path(): + project = "cuttlefish" + location = "mussel" + study = "winkle" + + expected = "projects/{project}/locations/{location}/studies/{study}".format( + project=project, location=location, study=study, + ) + actual = VizierServiceClient.study_path(project, location, study) + assert expected == actual + + +def test_parse_study_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "study": "abalone", + } + path = VizierServiceClient.study_path(**expected) + + # Check that the path construction is reversible. + actual = VizierServiceClient.parse_study_path(path) + assert expected == actual + + +def test_trial_path(): + project = "squid" + location = "clam" + study = "whelk" + trial = "octopus" + + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) + actual = VizierServiceClient.trial_path(project, location, study, trial) + assert expected == actual + + +def test_parse_trial_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", + } + path = VizierServiceClient.trial_path(**expected) + + # Check that the path construction is reversible. + actual = VizierServiceClient.parse_trial_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "winkle" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = VizierServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nautilus", + } + path = VizierServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. 
+ actual = VizierServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "scallop" + + expected = "folders/{folder}".format(folder=folder,) + actual = VizierServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "abalone", + } + path = VizierServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = VizierServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "squid" + + expected = "organizations/{organization}".format(organization=organization,) + actual = VizierServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "clam", + } + path = VizierServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = VizierServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "whelk" + + expected = "projects/{project}".format(project=project,) + actual = VizierServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "octopus", + } + path = VizierServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = VizierServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "oyster" + location = "nudibranch" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = VizierServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + } + path = VizierServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = VizierServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.VizierServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = VizierServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.VizierServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = VizierServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) From e79d5a49b481c1719e6ecd1d5fd43691f8c5c004 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Mon, 22 Mar 2021 12:53:54 -0700 Subject: [PATCH 7/7] chore: release 0.6.0 (#268) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 12 ++++++++++++ setup.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ea9ca7a7b3..be2d9a602f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [0.6.0](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.1...v0.6.0) (2021-03-22) + + +### Features + +* add Vizier service ([#266](https://www.github.com/googleapis/python-aiplatform/issues/266)) ([e5c1b1a](https://www.github.com/googleapis/python-aiplatform/commit/e5c1b1a4909d701efeb27f29af43a95516c51475)) + + +### Bug Fixes + +* skip create data labeling job sample tests ([#254](https://www.github.com/googleapis/python-aiplatform/issues/254)) ([116a29b](https://www.github.com/googleapis/python-aiplatform/commit/116a29b1efcebb15bad14c3c36d3591c09ef10be)) + ### [0.5.1](https://www.github.com/googleapis/python-aiplatform/compare/v0.5.0...v0.5.1) (2021-03-01) diff --git a/setup.py b/setup.py index a290702738..cc19d7a867 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ import setuptools # type: ignore name = "google-cloud-aiplatform" -version = "0.5.1" +version = "0.6.0" description = "Cloud AI Platform API client library" package_root = os.path.abspath(os.path.dirname(__file__))
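
For reference only (not part of the patch series above): a minimal sketch of how the Vizier client exercised by these unit tests might be used. The module path google.cloud.aiplatform_v1beta1 and the placeholder resource names are assumptions, since the patch does not show where the generated client is exported; the snippet only constructs the client and resource paths and issues no RPCs.

# Sketch only: assumes the Vizier GAPIC client ships under aiplatform_v1beta1,
# mirroring the construction and path helpers used in the unit tests above.
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.vizier_service import (
    VizierServiceClient,
)

# Anonymous credentials keep this constructible without a real project; no RPCs are made.
client = VizierServiceClient(credentials=credentials.AnonymousCredentials())

# Path helpers build and parse fully qualified resource names.
# "my-project", "us-central1", and "my-study" are placeholder values.
study = VizierServiceClient.study_path("my-project", "us-central1", "my-study")
assert study == "projects/my-project/locations/us-central1/studies/my-study"

# Parsing is the inverse of construction, as the tests above verify.
assert VizierServiceClient.parse_study_path(study) == {
    "project": "my-project",
    "location": "us-central1",
    "study": "my-study",
}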