From 7ec92b71be9c1d0d305421bb1b1dce5d92377bba Mon Sep 17 00:00:00 2001
From: Yoshi Automation Bot
Date: Mon, 16 Nov 2020 10:15:26 -0800
Subject: [PATCH] feat: add common resource paths, expose client transport
 (#87)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* changes without context

  autosynth cannot find the source of changes triggered by earlier changes
  in this repository, or by version upgrades to tools such as linters.

* chore(py-library): enable snippet-bot

Co-authored-by: Benjamin E. Coe

Source-Author: Takashi Matsuo
Source-Date: Tue Sep 1 17:14:08 2020 +0000
Source-Repo: googleapis/synthtool
Source-Sha: d91dd8aac77f7a9c5506c238038a26fa4f9e361e
Source-Link: https://github.com/googleapis/synthtool/commit/d91dd8aac77f7a9c5506c238038a26fa4f9e361e

* chore(py-library): update decrypt secrets file

* chore(py-library): update decrypt secrets file

From https://github.com/GoogleCloudPlatform/python-docs-samples/blob/master/scripts/decrypt-secrets.sh

* docs: explain conditional

Co-authored-by: Jeffrey Rennie

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Tue Sep 8 11:35:59 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: d302f93d7f47e2852e585ac35ab2d15585717ec0
Source-Link: https://github.com/googleapis/synthtool/commit/d302f93d7f47e2852e585ac35ab2d15585717ec0

* chore(python-library): use sphinx 1.5.5 for the docfx job

Originally tested at:
https://github.com/googleapis/python-texttospeech/pull/89

This change will fix the missing docstring in the yaml files.

Source-Author: Takashi Matsuo
Source-Date: Thu Sep 10 04:12:14 2020 +0000
Source-Repo: googleapis/synthtool
Source-Sha: ffcee7952b74f647cbb3ef021d95422f10816fca
Source-Link: https://github.com/googleapis/synthtool/commit/ffcee7952b74f647cbb3ef021d95422f10816fca

* build(python): use release-publish app for notifying GitHub of releas…

* build(python): use release-publish app for notifying GitHub of release status

* fix: re-add pypi password

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Wed Sep 16 08:46:42 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: 257fda18168bedb76985024bd198ed1725485488
Source-Link: https://github.com/googleapis/synthtool/commit/257fda18168bedb76985024bd198ed1725485488

* build(python): add secret manager in kokoro

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Wed Sep 16 10:24:40 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: dba48bb9bc6959c232bec9150ac6313b608fe7bd
Source-Link: https://github.com/googleapis/synthtool/commit/dba48bb9bc6959c232bec9150ac6313b608fe7bd

* chore(python): add sphinx doctest extension

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Mon Sep 21 13:09:57 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: 27f4406999b1eee29e04b09b2423a8e4646c7e24
Source-Link: https://github.com/googleapis/synthtool/commit/27f4406999b1eee29e04b09b2423a8e4646c7e24

* chore(python): remove note about editable installs

`pip install -e .` is supported and is how we install the library for tests.

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Tue Sep 22 12:06:12 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: a651c5fb763c69a921aecdd3e1d8dc51dbf20f8d
Source-Link: https://github.com/googleapis/synthtool/commit/a651c5fb763c69a921aecdd3e1d8dc51dbf20f8d

* chore(python): use BUILD_SPECIFIC_GCLOUD_PROJECT for samples

https://github.com/googleapis/python-talent/blob/ef045e8eb348db36d7a2a611e6f26b11530d273b/samples/snippets/noxfile_config.py#L27-L32

`BUILD_SPECIFIC_GCLOUD_PROJECT` is an alternate project used for sample tests
that do poorly with concurrent runs on the same project.

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Wed Sep 30 13:06:03 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: 9b0da5204ab90bcc36f8cd4e5689eff1a54cc3e4
Source-Link: https://github.com/googleapis/synthtool/commit/9b0da5204ab90bcc36f8cd4e5689eff1a54cc3e4
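For context, the noxfile_config.py linked in the bullet above is how a samples
suite opts into the alternate project. A minimal sketch of such an override,
assuming the standard samples noxfile template reads the `gcloud_project_env`
and `envs` keys; this sketch is illustrative and is not part of this patch:

    # samples/snippets/noxfile_config.py -- illustrative sketch only.
    TEST_CONFIG_OVERRIDE = {
        # Resolve the test project id from BUILD_SPECIFIC_GCLOUD_PROJECT
        # instead of the shared GOOGLE_CLOUD_PROJECT.
        "gcloud_project_env": "BUILD_SPECIFIC_GCLOUD_PROJECT",
        # Extra environment variables to inject into the test run, if any.
        "envs": {},
    }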
* chore(python): use 'setup.py' to detect repo root

Closes #792

Source-Author: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
Source-Date: Fri Oct 9 15:06:33 2020 -0600
Source-Repo: googleapis/synthtool
Source-Sha: e0ae456852bf22f38796deb79cff30b516fde244
Source-Link: https://github.com/googleapis/synthtool/commit/e0ae456852bf22f38796deb79cff30b516fde244

* build(python): samples tests should pass if no samples exist

Source-Author: Daniel Sanche
Source-Date: Wed Oct 14 08:00:06 2020 -0700
Source-Repo: googleapis/synthtool
Source-Sha: 477764cc4ee6db346d3febef2bb1ea0abf27de52
Source-Link: https://github.com/googleapis/synthtool/commit/477764cc4ee6db346d3febef2bb1ea0abf27de52

* chore(python_library): change the docs bucket name

Source-Author: Takashi Matsuo
Source-Date: Fri Oct 16 09:58:05 2020 -0700
Source-Repo: googleapis/synthtool
Source-Sha: da5c6050d13b4950c82666a81d8acd25157664ae
Source-Link: https://github.com/googleapis/synthtool/commit/da5c6050d13b4950c82666a81d8acd25157664ae

* chore(docs): update code of conduct of synthtool and templates

Source-Author: Christopher Wilcox
Source-Date: Thu Oct 22 14:22:01 2020 -0700
Source-Repo: googleapis/synthtool
Source-Sha: 5f6ef0ec5501d33c4667885b37a7685a30d41a76
Source-Link: https://github.com/googleapis/synthtool/commit/5f6ef0ec5501d33c4667885b37a7685a30d41a76

* docs: add proto-plus to intersphinx mapping

Source-Author: Tim Swast
Source-Date: Tue Oct 27 12:01:14 2020 -0500
Source-Repo: googleapis/synthtool
Source-Sha: ea52b8a0bd560f72f376efcf45197fb7c8869120
Source-Link: https://github.com/googleapis/synthtool/commit/ea52b8a0bd560f72f376efcf45197fb7c8869120

* fix(python_library): fix external unit test dependencies

I recently submitted https://github.com/googleapis/synthtool/pull/811/files,
allowing external dependencies for unit tests. This fixes a small missing
comma bug.

Source-Author: Daniel Sanche
Source-Date: Thu Oct 29 16:58:01 2020 -0700
Source-Repo: googleapis/synthtool
Source-Sha: 6542bd723403513626f61642fc02ddca528409aa
Source-Link: https://github.com/googleapis/synthtool/commit/6542bd723403513626f61642fc02ddca528409aa

* chore: add type hint check

Source-Author: Leah E. Cole <6719667+leahecole@users.noreply.github.com>
Source-Date: Wed Nov 4 17:36:32 2020 -0800
Source-Repo: googleapis/synthtool
Source-Sha: 3d3e94c4e02370f307a9a200b0c743c3d8d19f29
Source-Link: https://github.com/googleapis/synthtool/commit/3d3e94c4e02370f307a9a200b0c743c3d8d19f29

* chore: add blacken to template

Source-Author: Leah E. Cole <6719667+leahecole@users.noreply.github.com>
Source-Date: Thu Nov 5 15:22:03 2020 -0800
Source-Repo: googleapis/synthtool
Source-Sha: 1f1148d3c7a7a52f0c98077f976bd9b3c948ee2b
Source-Link: https://github.com/googleapis/synthtool/commit/1f1148d3c7a7a52f0c98077f976bd9b3c948ee2b

* fix: address lint issues

Source-Author: Leah E. Cole <6719667+leahecole@users.noreply.github.com>
Source-Date: Thu Nov 12 11:30:49 2020 -0800
Source-Repo: googleapis/synthtool
Source-Sha: e89175cf074dccc4babb4eca66ae913696e47a71
Source-Link: https://github.com/googleapis/synthtool/commit/e89175cf074dccc4babb4eca66ae913696e47a71
---
 .github/snippet-bot.yml | 0
 .kokoro/docs/common.cfg | 2 +-
 .kokoro/populate-secrets.sh | 43 +
 .kokoro/release/common.cfg | 50 +-
 .kokoro/samples/python3.6/common.cfg | 6 +
 .kokoro/samples/python3.7/common.cfg | 6 +
 .kokoro/samples/python3.8/common.cfg | 6 +
 .kokoro/test-samples.sh | 8 +-
 .kokoro/trampoline.sh | 15 +-
 CODE_OF_CONDUCT.md | 123 ++-
 CONTRIBUTING.rst | 19 -
 docs/conf.py | 4 +-
 docs/dataproc_v1/types.rst | 1 +
 docs/dataproc_v1beta2/types.rst | 1 +
 google/cloud/dataproc_v1/__init__.py | 4 +-
 .../async_client.py | 107 ++-
 .../autoscaling_policy_service/client.py | 169 +++-
 .../transports/base.py | 32 +-
 .../transports/grpc.py | 72 +-
 .../transports/grpc_asyncio.py | 65 +-
 .../cluster_controller/async_client.py | 109 ++-
 .../services/cluster_controller/client.py | 173 +++-
 .../cluster_controller/transports/base.py | 32 +-
 .../cluster_controller/transports/grpc.py | 72 +-
 .../transports/grpc_asyncio.py | 65 +-
 .../services/job_controller/async_client.py | 111 ++-
 .../services/job_controller/client.py | 173 +++-
 .../job_controller/transports/base.py | 36 +-
 .../job_controller/transports/grpc.py | 72 +-
 .../job_controller/transports/grpc_asyncio.py | 65 +-
 .../workflow_template_service/async_client.py | 124 ++-
 .../workflow_template_service/client.py | 178 +++-
 .../transports/base.py | 34 +-
 .../transports/grpc.py | 72 +-
 .../transports/grpc_asyncio.py | 65 +-
 .../dataproc_v1/types/autoscaling_policies.py | 8 +-
 google/cloud/dataproc_v1/types/clusters.py | 6 +-
 google/cloud/dataproc_v1/types/jobs.py | 58 +-
 google/cloud/dataproc_v1/types/operations.py | 4 +-
 .../dataproc_v1/types/workflow_templates.py | 10 +-
 .../async_client.py | 107 ++-
 .../autoscaling_policy_service/client.py | 169 +++-
 .../transports/base.py | 32 +-
 .../transports/grpc.py | 72 +-
 .../transports/grpc_asyncio.py | 65 +-
 .../cluster_controller/async_client.py | 109 ++-
 .../services/cluster_controller/client.py | 173 +++-
 .../cluster_controller/transports/base.py | 32 +-
 .../cluster_controller/transports/grpc.py | 72 +-
 .../transports/grpc_asyncio.py | 65 +-
 .../services/job_controller/async_client.py | 111 ++-
 .../services/job_controller/client.py | 173 +++-
 .../job_controller/transports/base.py | 36 +-
 .../job_controller/transports/grpc.py | 72 +-
 .../job_controller/transports/grpc_asyncio.py | 65 +-
 .../workflow_template_service/async_client.py | 124 ++-
 .../workflow_template_service/client.py | 178 +++-
 .../transports/base.py | 34 +-
 .../transports/grpc.py | 72 +-
 .../transports/grpc_asyncio.py | 65 +-
 .../types/autoscaling_policies.py | 8 +-
 .../cloud/dataproc_v1beta2/types/clusters.py | 6 +-
 google/cloud/dataproc_v1beta2/types/jobs.py | 58 +-
 .../dataproc_v1beta2/types/operations.py | 4 +-
 .../types/workflow_templates.py | 10 +-
 noxfile.py | 10 +-
 samples/snippets/noxfile.py | 24 +-
 scripts/decrypt-secrets.sh | 15 +-
 scripts/fixup_dataproc_v1_keywords.py | 1 +
 scripts/fixup_dataproc_v1beta2_keywords.py | 1 +
 synth.metadata | 197 +++-
 .../test_autoscaling_policy_service.py | 791 ++++++++++------
 .../dataproc_v1/test_cluster_controller.py | 824 ++++++++++-------
 .../gapic/dataproc_v1/test_job_controller.py | 824 ++++++++++-------
 .../test_workflow_template_service.py | 847 +++++++++++------
 .../test_autoscaling_policy_service.py | 791 ++++++++++------
 .../test_cluster_controller.py | 824 ++++++++++-------
 .../dataproc_v1beta2/test_job_controller.py | 824 ++++++++++-------
 .../test_workflow_template_service.py | 847 +++++++++++------
 79 files changed, 7352 insertions(+), 3580 deletions(-)
 create mode 100644 .github/snippet-bot.yml
 create mode 100755 .kokoro/populate-secrets.sh

diff --git a/.github/snippet-bot.yml b/.github/snippet-bot.yml
new file mode 100644
index 00000000..e69de29b

diff --git a/.kokoro/docs/common.cfg b/.kokoro/docs/common.cfg
index 676322b2..41eeb4d7 100644
--- a/.kokoro/docs/common.cfg
+++ b/.kokoro/docs/common.cfg
@@ -30,7 +30,7 @@ env_vars: {
 env_vars: {
     key: "V2_STAGING_BUCKET"
-    value: "docs-staging-v2-staging"
+    value: "docs-staging-v2"
 }

 # It will upload the docker image after successful builds.

diff --git a/.kokoro/populate-secrets.sh b/.kokoro/populate-secrets.sh
new file mode 100755
index 00000000..f5251425
--- /dev/null
+++ b/.kokoro/populate-secrets.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+# Copyright 2020 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -eo pipefail
+
+function now { date +"%Y-%m-%d %H:%M:%S" | tr -d '\n' ;}
+function msg { println "$*" >&2 ;}
+function println { printf '%s\n' "$(now) $*" ;}
+
+
+# Populates requested secrets set in SECRET_MANAGER_KEYS from service account:
+# kokoro-trampoline@cloud-devrel-kokoro-resources.iam.gserviceaccount.com
+SECRET_LOCATION="${KOKORO_GFILE_DIR}/secret_manager"
+msg "Creating folder on disk for secrets: ${SECRET_LOCATION}"
+mkdir -p ${SECRET_LOCATION}
+for key in $(echo ${SECRET_MANAGER_KEYS} | sed "s/,/ /g")
+do
+  msg "Retrieving secret ${key}"
+  docker run --entrypoint=gcloud \
+    --volume=${KOKORO_GFILE_DIR}:${KOKORO_GFILE_DIR} \
+    gcr.io/google.com/cloudsdktool/cloud-sdk \
+    secrets versions access latest \
+    --project cloud-devrel-kokoro-resources \
+    --secret ${key} > \
+    "${SECRET_LOCATION}/${key}"
+  if [[ $? == 0 ]]; then
+    msg "Secret written to ${SECRET_LOCATION}/${key}"
+  else
+    msg "Error retrieving secret ${key}"
+  fi
+done
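The script above writes each secret named in SECRET_MANAGER_KEYS to
${KOKORO_GFILE_DIR}/secret_manager/<key>. A hedged sketch of how a later
build step might read one back; the helper name is hypothetical:

    import os

    def read_kokoro_secret(key: str) -> str:
        """Read a secret previously fetched by populate-secrets.sh."""
        secret_dir = os.path.join(os.environ["KOKORO_GFILE_DIR"], "secret_manager")
        with open(os.path.join(secret_dir, key)) as f:
            return f.read().strip()

    # e.g. read_kokoro_secret("releasetool-publish-reporter-pem")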
diff --git a/.kokoro/release/common.cfg b/.kokoro/release/common.cfg
index 4bf136c7..9c364d34 100644
--- a/.kokoro/release/common.cfg
+++ b/.kokoro/release/common.cfg
@@ -23,42 +23,18 @@ env_vars: {
     value: "github/python-dataproc/.kokoro/release.sh"
 }

-# Fetch the token needed for reporting release status to GitHub
-before_action {
-  fetch_keystore {
-    keystore_resource {
-      keystore_config_id: 73713
-      keyname: "yoshi-automation-github-key"
-    }
-  }
-}
-
-# Fetch PyPI password
-before_action {
-  fetch_keystore {
-    keystore_resource {
-      keystore_config_id: 73713
-      keyname: "google_cloud_pypi_password"
-    }
-  }
-}
-
-# Fetch magictoken to use with Magic Github Proxy
-before_action {
-  fetch_keystore {
-    keystore_resource {
-      keystore_config_id: 73713
-      keyname: "releasetool-magictoken"
-    }
-  }
+# Fetch PyPI password
+before_action {
+  fetch_keystore {
+    keystore_resource {
+      keystore_config_id: 73713
+      keyname: "google_cloud_pypi_password"
+    }
+  }
 }

-# Fetch api key to use with Magic Github Proxy
-before_action {
-  fetch_keystore {
-    keystore_resource {
-      keystore_config_id: 73713
-      keyname: "magic-github-proxy-api-key"
-    }
-  }
-}
+# Tokens needed to report release status back to GitHub
+env_vars: {
+  key: "SECRET_MANAGER_KEYS"
+  value: "releasetool-publish-reporter-app,releasetool-publish-reporter-googleapis-installation,releasetool-publish-reporter-pem"
+}
\ No newline at end of file

diff --git a/.kokoro/samples/python3.6/common.cfg b/.kokoro/samples/python3.6/common.cfg
index c04328ca..c0dd20b7 100644
--- a/.kokoro/samples/python3.6/common.cfg
+++ b/.kokoro/samples/python3.6/common.cfg
@@ -13,6 +13,12 @@ env_vars: {
     value: "py-3.6"
 }

+# Declare build specific Cloud project.
+env_vars: {
+    key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+    value: "python-docs-samples-tests-py36"
+}
+
 env_vars: {
     key: "TRAMPOLINE_BUILD_FILE"
     value: "github/python-dataproc/.kokoro/test-samples.sh"

diff --git a/.kokoro/samples/python3.7/common.cfg b/.kokoro/samples/python3.7/common.cfg
index fd45b1ae..e6abf059 100644
--- a/.kokoro/samples/python3.7/common.cfg
+++ b/.kokoro/samples/python3.7/common.cfg
@@ -13,6 +13,12 @@ env_vars: {
     value: "py-3.7"
 }

+# Declare build specific Cloud project.
+env_vars: {
+    key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+    value: "python-docs-samples-tests-py37"
+}
+
 env_vars: {
     key: "TRAMPOLINE_BUILD_FILE"
     value: "github/python-dataproc/.kokoro/test-samples.sh"

diff --git a/.kokoro/samples/python3.8/common.cfg b/.kokoro/samples/python3.8/common.cfg
index 30105713..2ede3de9 100644
--- a/.kokoro/samples/python3.8/common.cfg
+++ b/.kokoro/samples/python3.8/common.cfg
@@ -13,6 +13,12 @@ env_vars: {
     value: "py-3.8"
 }

+# Declare build specific Cloud project.
+env_vars: {
+    key: "BUILD_SPECIFIC_GCLOUD_PROJECT"
+    value: "python-docs-samples-tests-py38"
+}
+
 env_vars: {
     key: "TRAMPOLINE_BUILD_FILE"
     value: "github/python-dataproc/.kokoro/test-samples.sh"

diff --git a/.kokoro/test-samples.sh b/.kokoro/test-samples.sh
index 028210c0..1e12cb1e 100755
--- a/.kokoro/test-samples.sh
+++ b/.kokoro/test-samples.sh
@@ -28,6 +28,12 @@ if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then
     git checkout $LATEST_RELEASE
 fi

+# Exit early if samples directory doesn't exist
+if [ ! -d "./samples" ]; then
+  echo "No tests run. `./samples` not found"
+  exit 0
+fi
+
 # Disable buffering, so that the logs stream through.
 export PYTHONUNBUFFERED=1

@@ -101,4 +107,4 @@ cd "$ROOT"
 # Workaround for Kokoro permissions issue: delete secrets
 rm testing/{test-env.sh,client-secrets.json,service-account.json}

-exit "$RTN"
\ No newline at end of file
+exit "$RTN"

diff --git a/.kokoro/trampoline.sh b/.kokoro/trampoline.sh
index e8c4251f..f39236e9 100755
--- a/.kokoro/trampoline.sh
+++ b/.kokoro/trampoline.sh
@@ -15,9 +15,14 @@

 set -eo pipefail

-python3 "${KOKORO_GFILE_DIR}/trampoline_v1.py" || ret_code=$?
+# Always run the cleanup script, regardless of the success of bouncing into
+# the container.
+function cleanup() {
+    chmod +x ${KOKORO_GFILE_DIR}/trampoline_cleanup.sh
+    ${KOKORO_GFILE_DIR}/trampoline_cleanup.sh
+    echo "cleanup";
+}
+trap cleanup EXIT

-chmod +x ${KOKORO_GFILE_DIR}/trampoline_cleanup.sh
-${KOKORO_GFILE_DIR}/trampoline_cleanup.sh || true
-
-exit ${ret_code}
+$(dirname $0)/populate-secrets.sh # Secret Manager secrets.
+python3 "${KOKORO_GFILE_DIR}/trampoline_v1.py"
\ No newline at end of file

diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index b3d1f602..039f4368 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -1,44 +1,95 @@
-# Contributor Code of Conduct
+# Code of Conduct

-As contributors and maintainers of this project,
-and in the interest of fostering an open and welcoming community,
-we pledge to respect all people who contribute through reporting issues,
-posting feature requests, updating documentation,
-submitting pull requests or patches, and other activities.
+## Our Pledge

-We are committed to making participation in this project
-a harassment-free experience for everyone,
-regardless of level of experience, gender, gender identity and expression,
-sexual orientation, disability, personal appearance,
-body size, race, ethnicity, age, religion, or nationality.
+In the interest of fostering an open and welcoming environment, we as
+contributors and maintainers pledge to making participation in our project and
+our community a harassment-free experience for everyone, regardless of age, body
+size, disability, ethnicity, gender identity and expression, level of
+experience, education, socio-economic status, nationality, personal appearance,
+race, religion, or sexual identity and orientation.
+
+## Our Standards
+
+Examples of behavior that contributes to creating a positive environment
+include:
+
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members

 Examples of unacceptable behavior by participants include:

-* The use of sexualized language or imagery
-* Personal attacks
-* Trolling or insulting/derogatory comments
-* Public or private harassment
-* Publishing other's private information,
-such as physical or electronic
-addresses, without explicit permission
-* Other unethical or unprofessional conduct.
+* The use of sexualized language or imagery and unwelcome sexual attention or
+  advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+  address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Our Responsibilities
+
+Project maintainers are responsible for clarifying the standards of acceptable
+behavior and are expected to take appropriate and fair corrective action in
+response to any instances of unacceptable behavior.

 Project maintainers have the right and responsibility to remove, edit, or reject
-comments, commits, code, wiki edits, issues, and other contributions
-that are not aligned to this Code of Conduct.
-By adopting this Code of Conduct,
-project maintainers commit themselves to fairly and consistently
-applying these principles to every aspect of managing this project.
-Project maintainers who do not follow or enforce the Code of Conduct
-may be permanently removed from the project team.
-
-This code of conduct applies both within project spaces and in public spaces
-when an individual is representing the project or its community.
-
-Instances of abusive, harassing, or otherwise unacceptable behavior
-may be reported by opening an issue
-or contacting one or more of the project maintainers.
-
-This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0,
-available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/)
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, or to ban temporarily or permanently any
+contributor for other behaviors that they deem inappropriate, threatening,
+offensive, or harmful.
+
+## Scope
+
+This Code of Conduct applies both within project spaces and in public spaces
+when an individual is representing the project or its community. Examples of
+representing a project or community include using an official project e-mail
+address, posting via an official social media account, or acting as an appointed
+representative at an online or offline event. Representation of a project may be
+further defined and clarified by project maintainers.
+
+This Code of Conduct also applies outside the project spaces when the Project
+Steward has a reasonable belief that an individual's behavior may have a
+negative impact on the project or its community.
+
+## Conflict Resolution
+
+We do not believe that all conflict is bad; healthy debate and disagreement
+often yield positive results. However, it is never okay to be disrespectful or
+to engage in behavior that violates the project’s code of conduct.
+
+If you see someone violating the code of conduct, you are encouraged to address
+the behavior directly with those involved. Many issues can be resolved quickly
+and easily, and this gives people more control over the outcome of their
+dispute. If you are unable to resolve the matter for any reason, or if the
+behavior is threatening or harassing, report it. We are dedicated to providing
+an environment where participants feel welcome and safe.
+
+
+Reports should be directed to *googleapis-stewards@google.com*, the
+Project Steward(s) for *Google Cloud Client Libraries*. It is the Project Steward’s duty to
+receive and address reported violations of the code of conduct. They will then
+work with a committee consisting of representatives from the Open Source
+Programs Office and the Google Open Source Strategy team. If for any reason you
+are uncomfortable reaching out to the Project Steward, please email
+opensource@google.com.
+
+We will investigate every complaint, but you may not receive a direct response.
+We will use our discretion in determining when and how to follow up on reported
+incidents, which may range from not taking action to permanent expulsion from
+the project and project-sponsored spaces. We will notify the accused of the
+report and provide them an opportunity to discuss it before any action is taken.
+The identity of the reporter will be omitted from the details of the report
+supplied to the accused. In potentially harmful situations, such as ongoing
+harassment or threats to anyone's safety, we may take action without notice.
+
+## Attribution
+
+This Code of Conduct is adapted from the Contributor Covenant, version 1.4,
+available at
+https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
\ No newline at end of file

diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 3759323e..43384e74 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -80,25 +80,6 @@ We use `nox <https://pypi.org/project/nox/>`__ to instrument our tests.

 .. _nox: https://pypi.org/project/nox/

-Note on Editable Installs / Develop Mode
-========================================
-
-- As mentioned previously, using ``setuptools`` in `develop mode`_
-  or a ``pip`` `editable install`_ is not possible with this
-  library. This is because this library uses `namespace packages`_.
-  For context see `Issue #2316`_ and the relevant `PyPA issue`_.
-
-  Since ``editable`` / ``develop`` mode can't be used, packages
-  need to be installed directly. Hence your changes to the source
-  tree don't get incorporated into the **already installed**
-  package.
-
-.. _namespace packages: https://www.python.org/dev/peps/pep-0420/
-.. _Issue #2316: https://github.com/GoogleCloudPlatform/google-cloud-python/issues/2316
-.. _PyPA issue: https://github.com/pypa/packaging-problems/issues/12
-.. _develop mode: https://setuptools.readthedocs.io/en/latest/setuptools.html#development-mode
-.. _editable install: https://pip.pypa.io/en/stable/reference/pip_install/#editable-installs
-
 *****************************************
 I'm getting weird errors... Can you help?
 *****************************************

diff --git a/docs/conf.py b/docs/conf.py
index 476a3a7e..9c46a0ba 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -29,7 +29,7 @@
 # -- General configuration ------------------------------------------------

 # If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = "1.6.3"
+needs_sphinx = "1.5.5"

 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
@@ -39,6 +39,7 @@
     "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
     "sphinx.ext.coverage",
+    "sphinx.ext.doctest",
     "sphinx.ext.napoleon",
     "sphinx.ext.todo",
     "sphinx.ext.viewcode",
@@ -348,6 +349,7 @@
     "google-auth": ("https://google-auth.readthedocs.io/en/stable", None),
     "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,),
     "grpc": ("https://grpc.io/grpc/python/", None),
+    "proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None),
 }

diff --git a/docs/dataproc_v1/types.rst b/docs/dataproc_v1/types.rst
index 5cd2ad4b..5dde0cd6 100644
--- a/docs/dataproc_v1/types.rst
+++ b/docs/dataproc_v1/types.rst
@@ -3,3 +3,4 @@ Types for Google Cloud Dataproc v1 API

 .. automodule:: google.cloud.dataproc_v1.types
     :members:
+    :show-inheritance:

diff --git a/docs/dataproc_v1beta2/types.rst b/docs/dataproc_v1beta2/types.rst
index e0972271..e3dba489 100644
--- a/docs/dataproc_v1beta2/types.rst
+++ b/docs/dataproc_v1beta2/types.rst
@@ -3,3 +3,4 @@ Types for Google Cloud Dataproc v1beta2 API

 .. automodule:: google.cloud.dataproc_v1beta2.types
     :members:
+    :show-inheritance:

diff --git a/google/cloud/dataproc_v1/__init__.py b/google/cloud/dataproc_v1/__init__.py
index 82d780ab..35887dcc 100644
--- a/google/cloud/dataproc_v1/__init__.py
+++ b/google/cloud/dataproc_v1/__init__.py
@@ -147,7 +147,6 @@
     "InstantiateInlineWorkflowTemplateRequest",
     "InstantiateWorkflowTemplateRequest",
     "Job",
-    "JobControllerClient",
     "JobMetadata",
     "JobPlacement",
     "JobReference",
@@ -192,6 +191,7 @@
     "WorkflowNode",
     "WorkflowTemplate",
     "WorkflowTemplatePlacement",
-    "YarnApplication",
     "WorkflowTemplateServiceClient",
+    "YarnApplication",
+    "JobControllerClient",
 )
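The diffs that follow add the common resource-path helpers to every client.
A short sketch of how they behave, based directly on the implementations
added in client.py further below:

    from google.cloud.dataproc_v1 import AutoscalingPolicyServiceClient

    # Build a fully-qualified resource name from its components.
    path = AutoscalingPolicyServiceClient.common_project_path("my-project")
    assert path == "projects/my-project"

    # And invert it back into its component segments.
    assert AutoscalingPolicyServiceClient.parse_common_project_path(path) == {
        "project": "my-project"
    }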
diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/async_client.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/async_client.py
index e1eeca2f..fa91a7e7 100644
--- a/google/cloud/dataproc_v1/services/autoscaling_policy_service/async_client.py
+++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/async_client.py
@@ -31,7 +31,7 @@
 from google.cloud.dataproc_v1.services.autoscaling_policy_service import pagers
 from google.cloud.dataproc_v1.types import autoscaling_policies

-from .transports.base import AutoscalingPolicyServiceTransport
+from .transports.base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO
 from .transports.grpc_asyncio import AutoscalingPolicyServiceGrpcAsyncIOTransport
 from .client import AutoscalingPolicyServiceClient

@@ -49,10 +49,55 @@ class AutoscalingPolicyServiceAsyncClient:
     autoscaling_policy_path = staticmethod(
         AutoscalingPolicyServiceClient.autoscaling_policy_path
     )
+    parse_autoscaling_policy_path = staticmethod(
+        AutoscalingPolicyServiceClient.parse_autoscaling_policy_path
+    )
+
+    common_billing_account_path = staticmethod(
+        AutoscalingPolicyServiceClient.common_billing_account_path
+    )
+    parse_common_billing_account_path = staticmethod(
+        AutoscalingPolicyServiceClient.parse_common_billing_account_path
+    )
+
+    common_folder_path = staticmethod(AutoscalingPolicyServiceClient.common_folder_path)
+    parse_common_folder_path = staticmethod(
+        AutoscalingPolicyServiceClient.parse_common_folder_path
+    )
+
+    common_organization_path = staticmethod(
+        AutoscalingPolicyServiceClient.common_organization_path
+    )
+    parse_common_organization_path = staticmethod(
+        AutoscalingPolicyServiceClient.parse_common_organization_path
+    )
+
+    common_project_path = staticmethod(
+        AutoscalingPolicyServiceClient.common_project_path
+    )
+    parse_common_project_path = staticmethod(
+        AutoscalingPolicyServiceClient.parse_common_project_path
+    )
+
+    common_location_path = staticmethod(
+        AutoscalingPolicyServiceClient.common_location_path
+    )
+    parse_common_location_path = staticmethod(
+        AutoscalingPolicyServiceClient.parse_common_location_path
+    )

     from_service_account_file = AutoscalingPolicyServiceClient.from_service_account_file
     from_service_account_json = from_service_account_file

+    @property
+    def transport(self) -> AutoscalingPolicyServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            AutoscalingPolicyServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
     get_transport_class = functools.partial(
         type(AutoscalingPolicyServiceClient).get_transport_class,
         type(AutoscalingPolicyServiceClient),
@@ -64,6 +109,7 @@ def __init__(
         credentials: credentials.Credentials = None,
         transport: Union[str, AutoscalingPolicyServiceTransport] = "grpc_asyncio",
         client_options: ClientOptions = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the autoscaling policy service client.

@@ -79,16 +125,19 @@ def __init__(
             client_options (ClientOptions): Custom options for the client. It
                 won't take effect if a ``transport`` instance is provided.
                 (1) The ``api_endpoint`` property can be used to override the
-                default endpoint provided by the client. GOOGLE_API_USE_MTLS
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
                 environment variable can also be used to override the endpoint:
                 "always" (always use the default mTLS endpoint), "never" (always
-                use the default regular endpoint, this is the default value for
-                the environment variable) and "auto" (auto switch to the default
-                mTLS endpoint if client SSL credentials is present). However,
-                the ``api_endpoint`` property takes precedence if provided.
-                (2) The ``client_cert_source`` property is used to provide client
-                SSL credentials for mutual TLS transport. If not provided, the
-                default SSL credentials will be used if present.
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.

         Raises:
             google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
@@ -96,7 +145,10 @@ def __init__(
         """
         self._client = AutoscalingPolicyServiceClient(
-            credentials=credentials, transport=transport, client_options=client_options,
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
         )

     async def create_autoscaling_policy(
@@ -154,7 +206,8 @@
         # Create or coerce a protobuf request object.
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
-        if request is not None and any([parent, policy]):
+        has_flattened_params = any([parent, policy])
+        if request is not None and has_flattened_params:
             raise ValueError(
                 "If the `request` argument is set, then none of "
                 "the individual field arguments should be set."
@@ -175,7 +228,7 @@
         rpc = gapic_v1.method_async.wrap_method(
             self._client._transport.create_autoscaling_policy,
             default_timeout=600.0,
-            client_info=_client_info,
+            client_info=DEFAULT_CLIENT_INFO,
         )

         # Certain fields should be provided within the metadata header;
@@ -230,7 +283,8 @@
         # Create or coerce a protobuf request object.
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
-        if request is not None and any([policy]):
+        has_flattened_params = any([policy])
+        if request is not None and has_flattened_params:
             raise ValueError(
                 "If the `request` argument is set, then none of "
                 "the individual field arguments should be set."
@@ -253,11 +307,11 @@
                 maximum=60.0,
                 multiplier=1.3,
                 predicate=retries.if_exception_type(
-                    exceptions.ServiceUnavailable, exceptions.DeadlineExceeded,
+                    exceptions.DeadlineExceeded, exceptions.ServiceUnavailable,
                 ),
             ),
             default_timeout=600.0,
-            client_info=_client_info,
+            client_info=DEFAULT_CLIENT_INFO,
         )

         # Certain fields should be provided within the metadata header;
@@ -321,7 +375,8 @@
         # Create or coerce a protobuf request object.
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
-        if request is not None and any([name]):
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
             raise ValueError(
                 "If the `request` argument is set, then none of "
                 "the individual field arguments should be set."
@@ -344,11 +399,11 @@
                 maximum=60.0,
                 multiplier=1.3,
                 predicate=retries.if_exception_type(
-                    exceptions.ServiceUnavailable, exceptions.DeadlineExceeded,
+                    exceptions.DeadlineExceeded, exceptions.ServiceUnavailable,
                 ),
             ),
             default_timeout=600.0,
-            client_info=_client_info,
+            client_info=DEFAULT_CLIENT_INFO,
         )

         # Certain fields should be provided within the metadata header;
@@ -413,7 +468,8 @@
         # Create or coerce a protobuf request object.
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
-        if request is not None and any([parent]):
+        has_flattened_params = any([parent])
+        if request is not None and has_flattened_params:
             raise ValueError(
                 "If the `request` argument is set, then none of "
                 "the individual field arguments should be set."
@@ -436,11 +492,11 @@
                 maximum=60.0,
                 multiplier=1.3,
                 predicate=retries.if_exception_type(
-                    exceptions.ServiceUnavailable, exceptions.DeadlineExceeded,
+                    exceptions.DeadlineExceeded, exceptions.ServiceUnavailable,
                 ),
             ),
             default_timeout=600.0,
-            client_info=_client_info,
+            client_info=DEFAULT_CLIENT_INFO,
         )

         # Certain fields should be provided within the metadata header;
@@ -508,7 +564,8 @@
         # Create or coerce a protobuf request object.
         # Sanity check: If we got a request object, we should *not* have
         # gotten any keyword arguments that map to the request.
-        if request is not None and any([name]):
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
             raise ValueError(
                 "If the `request` argument is set, then none of "
                 "the individual field arguments should be set."
@@ -527,7 +584,7 @@
         rpc = gapic_v1.method_async.wrap_method(
             self._client._transport.delete_autoscaling_policy,
             default_timeout=600.0,
-            client_info=_client_info,
+            client_info=DEFAULT_CLIENT_INFO,
         )

         # Certain fields should be provided within the metadata header;
@@ -543,11 +600,11 @@

 try:
-    _client_info = gapic_v1.client_info.ClientInfo(
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
         gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version,
     )
 except pkg_resources.DistributionNotFound:
-    _client_info = gapic_v1.client_info.ClientInfo()
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


 __all__ = ("AutoscalingPolicyServiceAsyncClient",)
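The headline change in this diff is the new public `transport` property. A
hedged usage sketch, assuming Application Default Credentials are configured:

    from google.cloud.dataproc_v1 import AutoscalingPolicyServiceClient

    client = AutoscalingPolicyServiceClient()

    # The transport (and its underlying gRPC channel) is now reachable as
    # public API instead of the private `_transport` attribute.
    transport = client.transport
    channel = transport.grpc_channel  # the underlying grpc.Channel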
+ """ + return self._transport + @staticmethod def autoscaling_policy_path( project: str, location: str, autoscaling_policy: str, @@ -150,12 +161,72 @@ def parse_autoscaling_policy_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, AutoscalingPolicyServiceTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, AutoscalingPolicyServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the autoscaling policy service client. @@ -168,48 +239,74 @@ def __init__( transport (Union[str, ~.AutoscalingPolicyServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -233,11 +330,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_autoscaling_policy( @@ -678,11 +775,11 @@ def delete_autoscaling_policy( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("AutoscalingPolicyServiceClient",) diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/base.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/base.py index 0c609a8b..46201f4b 100644 --- a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/base.py +++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -30,11 +30,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class AutoscalingPolicyServiceTransport(abc.ABC): @@ -50,6 +50,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -67,6 +68,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -94,15 +100,15 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
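The mTLS logic above is driven by two environment variables. A hedged sketch
of opting into mutual TLS under the new scheme; the certificate file names
are placeholders:

    import os
    from google.api_core.client_options import ClientOptions
    from google.cloud.dataproc_v1 import AutoscalingPolicyServiceClient

    # Gate client certificates on; "auto" endpoint selection is the default.
    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"
    os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "auto"

    def client_cert_source():
        # Hypothetical loader: return (cert_bytes, key_bytes) in PEM format.
        return open("client.pem", "rb").read(), open("client.key", "rb").read()

    client = AutoscalingPolicyServiceClient(
        client_options=ClientOptions(client_cert_source=client_cert_source)
    )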
        self._wrapped_methods = {
             self.create_autoscaling_policy: gapic_v1.method.wrap_method(
                 self.create_autoscaling_policy,
                 default_timeout=600.0,
-                client_info=_client_info,
+                client_info=client_info,
             ),
             self.update_autoscaling_policy: gapic_v1.method.wrap_method(
                 self.update_autoscaling_policy,
@@ -111,11 +117,11 @@ def _prep_wrapped_messages(self):
                     maximum=60.0,
                     multiplier=1.3,
                     predicate=retries.if_exception_type(
-                        exceptions.ServiceUnavailable, exceptions.DeadlineExceeded,
+                        exceptions.DeadlineExceeded, exceptions.ServiceUnavailable,
                     ),
                 ),
                 default_timeout=600.0,
-                client_info=_client_info,
+                client_info=client_info,
             ),
             self.get_autoscaling_policy: gapic_v1.method.wrap_method(
                 self.get_autoscaling_policy,
@@ -124,11 +130,11 @@ def _prep_wrapped_messages(self):
                     maximum=60.0,
                     multiplier=1.3,
                     predicate=retries.if_exception_type(
-                        exceptions.ServiceUnavailable, exceptions.DeadlineExceeded,
+                        exceptions.DeadlineExceeded, exceptions.ServiceUnavailable,
                     ),
                 ),
                 default_timeout=600.0,
-                client_info=_client_info,
+                client_info=client_info,
             ),
             self.list_autoscaling_policies: gapic_v1.method.wrap_method(
                 self.list_autoscaling_policies,
@@ -137,16 +143,16 @@ def _prep_wrapped_messages(self):
                     maximum=60.0,
                     multiplier=1.3,
                     predicate=retries.if_exception_type(
-                        exceptions.ServiceUnavailable, exceptions.DeadlineExceeded,
+                        exceptions.DeadlineExceeded, exceptions.ServiceUnavailable,
                     ),
                 ),
                 default_timeout=600.0,
-                client_info=_client_info,
+                client_info=client_info,
             ),
             self.delete_autoscaling_policy: gapic_v1.method.wrap_method(
                 self.delete_autoscaling_policy,
                 default_timeout=600.0,
-                client_info=_client_info,
+                client_info=client_info,
             ),
         }

diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc.py
index d4fbfe1b..2d2e2746 100644
--- a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc.py
+++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc.py
@@ -15,20 +15,21 @@
 #  limitations under the License.
 #

+import warnings
 from typing import Callable, Dict, Optional, Sequence, Tuple

 from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
 from google import auth  # type: ignore
 from google.auth import credentials  # type: ignore
 from google.auth.transport.grpc import SslCredentials  # type: ignore
-
 import grpc  # type: ignore

 from google.cloud.dataproc_v1.types import autoscaling_policies
 from google.protobuf import empty_pb2 as empty  # type: ignore

-from .base import AutoscalingPolicyServiceTransport
+from .base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO


 class AutoscalingPolicyServiceGrpcTransport(AutoscalingPolicyServiceTransport):
@@ -57,7 +58,9 @@ def __init__(
         channel: grpc.Channel = None,
         api_mtls_endpoint: str = None,
         client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
-        quota_project_id: Optional[str] = None
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the transport.

@@ -76,16 +79,23 @@ def __init__(
                 ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
-            api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
-                provided, it overrides the ``host`` argument and tries to create
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
                 a mutual TLS channel with client SSL credentials from
                 ``client_cert_source`` or applicatin default SSL credentials.
-            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
-                callback to provide client SSL certificate bytes and private key
-                bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
-                is None.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
             quota_project_id (Optional[str]): An optional project to use for billing
                 and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.

         Raises:
           google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
@@ -93,6 +103,8 @@ def __init__(
           google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
               and ``credentials_file`` are passed.
         """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
         if channel:
             # Sanity check: Ensure that channel and credentials are not both
             # provided.
@@ -100,7 +112,13 @@ def __init__(

             # If a channel was explicitly provided, set it.
             self._grpc_channel = channel
+            self._ssl_channel_credentials = None
         elif api_mtls_endpoint:
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )
+
             host = (
                 api_mtls_endpoint
                 if ":" in api_mtls_endpoint
@@ -131,6 +149,24 @@ def __init__(
                 scopes=scopes or self.AUTH_SCOPES,
                 quota_project_id=quota_project_id,
             )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )

         self._stubs = {}  # type: Dict[str, Callable]

@@ -141,6 +177,7 @@ def __init__(
             credentials_file=credentials_file,
             scopes=scopes or self.AUTH_SCOPES,
             quota_project_id=quota_project_id,
+            client_info=client_info,
         )

     @classmethod
@@ -151,7 +188,7 @@ def create_channel(
         credentials_file: str = None,
         scopes: Optional[Sequence[str]] = None,
         quota_project_id: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
@@ -185,24 +222,13 @@ def create_channel(
             credentials_file=credentials_file,
             scopes=scopes,
             quota_project_id=quota_project_id,
-            **kwargs
+            **kwargs,
         )

     @property
     def grpc_channel(self) -> grpc.Channel:
-        """Create the channel designed to connect to this service.
-
-        This property caches on the instance; repeated calls return
-        the same channel.
+        """Return the channel designed to connect to this service.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
-        if not hasattr(self, "_grpc_channel"):
-            self._grpc_channel = self.create_channel(
-                self._host, credentials=self._credentials,
-            )
-
-        # Return the channel from cache.
         return self._grpc_channel

     @property
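The diff above deprecates the api_mtls_endpoint/client_cert_source pair in
favor of passing gRPC channel credentials directly. A hedged migration
sketch; the certificate file names are placeholders:

    import grpc
    from google.cloud.dataproc_v1.services.autoscaling_policy_service.transports.grpc import (
        AutoscalingPolicyServiceGrpcTransport,
    )

    # Build channel-level SSL credentials once, instead of handing the
    # transport a callback plus a separate mTLS endpoint.
    ssl_credentials = grpc.ssl_channel_credentials(
        certificate_chain=open("client.pem", "rb").read(),
        private_key=open("client.key", "rb").read(),
    )
    transport = AutoscalingPolicyServiceGrpcTransport(
        host="dataproc.mtls.googleapis.com",
        ssl_channel_credentials=ssl_credentials,
    )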
diff --git a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc_asyncio.py b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc_asyncio.py
index 1eb47af4..bd9d91a0 100644
--- a/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc_asyncio.py
+++ b/google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc_asyncio.py
@@ -15,9 +15,12 @@
 #  limitations under the License.
 #

+import warnings
 from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple

+from google.api_core import gapic_v1  # type: ignore
 from google.api_core import grpc_helpers_async  # type: ignore
+from google import auth  # type: ignore
 from google.auth import credentials  # type: ignore
 from google.auth.transport.grpc import SslCredentials  # type: ignore

@@ -27,7 +30,7 @@
 from google.cloud.dataproc_v1.types import autoscaling_policies
 from google.protobuf import empty_pb2 as empty  # type: ignore

-from .base import AutoscalingPolicyServiceTransport
+from .base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO
 from .grpc import AutoscalingPolicyServiceGrpcTransport

@@ -99,7 +102,9 @@ def __init__(
         channel: aio.Channel = None,
         api_mtls_endpoint: str = None,
         client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
         quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the transport.

@@ -119,16 +124,23 @@ def __init__(
                 are passed to :func:`google.auth.default`.
             channel (Optional[aio.Channel]): A ``Channel`` instance through
                 which to make calls.
-            api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If
-                provided, it overrides the ``host`` argument and tries to create
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
                 a mutual TLS channel with client SSL credentials from
                 ``client_cert_source`` or applicatin default SSL credentials.
-            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A
-                callback to provide client SSL certificate bytes and private key
-                bytes, both in PEM format. It is ignored if ``api_mtls_endpoint``
-                is None.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
             quota_project_id (Optional[str]): An optional project to use for billing
                 and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.

         Raises:
             google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
                 creation failed for any reason.
             google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
                 and ``credentials_file`` are passed.
         """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
         if channel:
             # Sanity check: Ensure that channel and credentials are not both
             # provided.
@@ -143,13 +157,24 @@ def __init__(

             # If a channel was explicitly provided, set it.
             self._grpc_channel = channel
+            self._ssl_channel_credentials = None
         elif api_mtls_endpoint:
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )
+
             host = (
                 api_mtls_endpoint
                 if ":" in api_mtls_endpoint
                 else api_mtls_endpoint + ":443"
             )

+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
             # Create SSL credentials with client_cert_source or application
             # default SSL credentials.
             if client_cert_source:
@@ -169,6 +194,24 @@ def __init__(
                 scopes=scopes or self.AUTH_SCOPES,
                 quota_project_id=quota_project_id,
             )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )

         # Run the base constructor.
         super().__init__(
@@ -177,6 +220,7 @@ def __init__(
             credentials_file=credentials_file,
             scopes=scopes or self.AUTH_SCOPES,
             quota_project_id=quota_project_id,
+            client_info=client_info,
         )

         self._stubs = {}

@@ -188,13 +232,6 @@ def grpc_channel(self) -> aio.Channel:
         This property caches on the instance; repeated calls return
         the same channel.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
-        if not hasattr(self, "_grpc_channel"):
-            self._grpc_channel = self.create_channel(
-                self._host, credentials=self._credentials,
-            )
-
-        # Return the channel from cache.
         return self._grpc_channel
return self._grpc_channel diff --git a/google/cloud/dataproc_v1/services/cluster_controller/async_client.py b/google/cloud/dataproc_v1/services/cluster_controller/async_client.py index 1ea1637c..75a59d2d 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/async_client.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/async_client.py @@ -28,15 +28,15 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1.services.cluster_controller import pagers from google.cloud.dataproc_v1.types import clusters from google.cloud.dataproc_v1.types import operations from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import ClusterControllerTransport +from .transports.base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import ClusterControllerGrpcAsyncIOTransport from .client import ClusterControllerClient @@ -51,9 +51,47 @@ class ClusterControllerAsyncClient: DEFAULT_ENDPOINT = ClusterControllerClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = ClusterControllerClient.DEFAULT_MTLS_ENDPOINT + common_billing_account_path = staticmethod( + ClusterControllerClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ClusterControllerClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(ClusterControllerClient.common_folder_path) + parse_common_folder_path = staticmethod( + ClusterControllerClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + ClusterControllerClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + ClusterControllerClient.parse_common_organization_path + ) + + common_project_path = staticmethod(ClusterControllerClient.common_project_path) + parse_common_project_path = staticmethod( + ClusterControllerClient.parse_common_project_path + ) + + common_location_path = staticmethod(ClusterControllerClient.common_location_path) + parse_common_location_path = staticmethod( + ClusterControllerClient.parse_common_location_path + ) + from_service_account_file = ClusterControllerClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> ClusterControllerTransport: + """Return the transport used by the client instance. + + Returns: + ClusterControllerTransport: The transport used by the client instance. + """ + return self._client.transport + get_transport_class = functools.partial( type(ClusterControllerClient).get_transport_class, type(ClusterControllerClient) ) @@ -64,6 +102,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, ClusterControllerTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the cluster controller client. @@ -79,16 +118,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. 
GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -96,7 +138,10 @@ def __init__( """ self._client = ClusterControllerClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def create_cluster( @@ -156,7 +201,8 @@ async def create_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster]): + has_flattened_params = any([project_id, region, cluster]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -185,7 +231,7 @@ async def create_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -331,9 +377,10 @@ async def update_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any( + has_flattened_params = any( [project_id, region, cluster_name, cluster, update_mask] - ): + ) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -366,7 +413,7 @@ async def update_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -451,7 +498,8 @@ async def delete_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([project_id, region, cluster_name]): + has_flattened_params = any([project_id, region, cluster_name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -480,7 +528,7 @@ async def delete_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -550,7 +598,8 @@ async def get_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster_name]): + has_flattened_params = any([project_id, region, cluster_name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -577,13 +626,13 @@ async def get_cluster( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -669,7 +718,8 @@ async def list_clusters( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, filter]): + has_flattened_params = any([project_id, region, filter]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -696,13 +746,13 @@ async def list_clusters( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -778,7 +828,8 @@ async def diagnose_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster_name]): + has_flattened_params = any([project_id, region, cluster_name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -807,7 +858,7 @@ async def diagnose_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. 
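Each of these hunks rewrites the same guard: a ``request`` message and the flattened keyword fields are mutually exclusive, and the ``any([...])`` test is hoisted into a named ``has_flattened_params`` variable for readability. A self-contained sketch of the guard, with hypothetical parameters:

    from typing import Optional

    def check_exclusive(
        request: Optional[dict],
        project_id: Optional[str] = None,
        region: Optional[str] = None,
        cluster_name: Optional[str] = None,
    ) -> None:
        # Same shape as the patched guard in each client method.
        has_flattened_params = any([project_id, region, cluster_name])
        if request is not None and has_flattened_params:
            raise ValueError(
                "If the `request` argument is set, then none of "
                "the individual field arguments should be set."
            )

    check_exclusive(None, project_id="my-project", region="us-central1")  # ok
    # check_exclusive({"project_id": "p"}, project_id="p")  # raises ValueError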
@@ -826,11 +877,11 @@ async def diagnose_cluster( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("ClusterControllerAsyncClient",) diff --git a/google/cloud/dataproc_v1/services/cluster_controller/client.py b/google/cloud/dataproc_v1/services/cluster_controller/client.py index 7f895f39..42594c47 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/client.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/client.py @@ -16,29 +16,31 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1.services.cluster_controller import pagers from google.cloud.dataproc_v1.types import clusters from google.cloud.dataproc_v1.types import operations from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import ClusterControllerTransport +from .transports.base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import ClusterControllerGrpcTransport from .transports.grpc_asyncio import ClusterControllerGrpcAsyncIOTransport @@ -137,12 +139,81 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> ClusterControllerTransport: + """Return the transport used by the client instance. + + Returns: + ClusterControllerTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, ClusterControllerTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ClusterControllerTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the cluster controller client. @@ -155,48 +226,74 @@ def __init__( transport (Union[str, ~.ClusterControllerTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -220,11 +317,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_cluster( @@ -922,11 +1019,11 @@ def diagnose_cluster( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("ClusterControllerClient",) diff --git a/google/cloud/dataproc_v1/services/cluster_controller/transports/base.py b/google/cloud/dataproc_v1/services/cluster_controller/transports/base.py index 993de639..caccd04e 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/transports/base.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -31,11 +31,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class ClusterControllerTransport(abc.ABC): @@ -51,6 +51,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -68,6 +69,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -95,9 +101,9 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { self.create_cluster: gapic_v1.method.wrap_method( @@ -109,7 +115,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.update_cluster: gapic_v1.method.wrap_method( self.update_cluster, @@ -120,7 +126,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.delete_cluster: gapic_v1.method.wrap_method( self.delete_cluster, @@ -131,7 +137,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.get_cluster: gapic_v1.method.wrap_method( self.get_cluster, @@ -140,13 +146,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.list_clusters: gapic_v1.method.wrap_method( self.list_clusters, @@ -155,13 +161,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.diagnose_cluster: gapic_v1.method.wrap_method( self.diagnose_cluster, @@ -172,7 +178,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc.py b/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc.py index 46f0b416..7c8b83c2 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc.py @@ -15,21 +15,22 @@ # limitations under the License. # +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1.types import clusters from google.longrunning import operations_pb2 as operations # type: ignore -from .base import ClusterControllerTransport +from .base import ClusterControllerTransport, DEFAULT_CLIENT_INFO class ClusterControllerGrpcTransport(ClusterControllerTransport): @@ -58,7 +59,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -77,16 +80,23 @@ def __init__( ignored if ``channel`` is provided. 
channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -94,6 +104,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -101,7 +113,13 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -132,6 +150,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -142,6 +178,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -152,7 +189,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -186,24 +223,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. 
- - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc_asyncio.py b/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc_asyncio.py index e8e49b6e..96c998f6 100644 --- a/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1/services/cluster_controller/transports/grpc_asyncio.py @@ -15,10 +15,13 @@ # limitations under the License. # +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -28,7 +31,7 @@ from google.cloud.dataproc_v1.types import clusters from google.longrunning import operations_pb2 as operations # type: ignore -from .base import ClusterControllerTransport +from .base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .grpc import ClusterControllerGrpcTransport @@ -100,7 +103,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -120,16 +125,23 @@ def __init__( are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. 
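A short sketch of what the new ``client_info`` parameter is for: it lets a wrapping library report its own version in the user-agent metadata sent with every request. The version strings here are hypothetical, and ``ClusterControllerClient`` is used only as an example entry point:

    from google.api_core import gapic_v1
    from google.cloud import dataproc_v1

    my_info = gapic_v1.client_info.ClientInfo(
        client_library_version="1.2.3",   # hypothetical wrapper version
        user_agent="my-wrapper/1.2.3",    # prepended to the user-agent
    )
    client = dataproc_v1.ClusterControllerClient(client_info=my_info)

The client forwards ``client_info`` to its transport, which stamps it on each wrapped method, as the ``transports/base.py`` hunks show.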
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -137,6 +149,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -144,13 +158,24 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -170,6 +195,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -178,6 +221,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -189,13 +233,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. 
return self._grpc_channel diff --git a/google/cloud/dataproc_v1/services/job_controller/async_client.py b/google/cloud/dataproc_v1/services/job_controller/async_client.py index ed4b2e02..8eaf753e 100644 --- a/google/cloud/dataproc_v1/services/job_controller/async_client.py +++ b/google/cloud/dataproc_v1/services/job_controller/async_client.py @@ -28,12 +28,12 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1.services.job_controller import pagers from google.cloud.dataproc_v1.types import jobs -from .transports.base import JobControllerTransport +from .transports.base import JobControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import JobControllerGrpcAsyncIOTransport from .client import JobControllerClient @@ -46,9 +46,47 @@ class JobControllerAsyncClient: DEFAULT_ENDPOINT = JobControllerClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = JobControllerClient.DEFAULT_MTLS_ENDPOINT + common_billing_account_path = staticmethod( + JobControllerClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + JobControllerClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(JobControllerClient.common_folder_path) + parse_common_folder_path = staticmethod( + JobControllerClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + JobControllerClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + JobControllerClient.parse_common_organization_path + ) + + common_project_path = staticmethod(JobControllerClient.common_project_path) + parse_common_project_path = staticmethod( + JobControllerClient.parse_common_project_path + ) + + common_location_path = staticmethod(JobControllerClient.common_location_path) + parse_common_location_path = staticmethod( + JobControllerClient.parse_common_location_path + ) + from_service_account_file = JobControllerClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> JobControllerTransport: + """Return the transport used by the client instance. + + Returns: + JobControllerTransport: The transport used by the client instance. + """ + return self._client.transport + get_transport_class = functools.partial( type(JobControllerClient).get_transport_class, type(JobControllerClient) ) @@ -59,6 +97,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, JobControllerTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the job controller client. @@ -74,16 +113,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -91,7 +133,10 @@ def __init__( """ self._client = JobControllerClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def submit_job( @@ -142,7 +187,8 @@ async def submit_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job]): + has_flattened_params = any([project_id, region, job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -171,7 +217,7 @@ async def submit_job( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -232,7 +278,8 @@ async def submit_job_as_operation( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job]): + has_flattened_params = any([project_id, region, job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -261,7 +308,7 @@ async def submit_job_as_operation( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -328,7 +375,8 @@ async def get_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job_id]): + has_flattened_params = any([project_id, region, job_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
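The common resource path helpers re-exported above are build/parse inverses; the parse side relies on the named groups in the regular expressions (``(?P<project>...)`` and friends). A quick round-trip with illustrative values:

    from google.cloud.dataproc_v1 import JobControllerClient

    path = JobControllerClient.common_location_path("my-project", "us-central1")
    assert path == "projects/my-project/locations/us-central1"

    parsed = JobControllerClient.parse_common_location_path(path)
    assert parsed == {"project": "my-project", "location": "us-central1"}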
@@ -355,13 +403,13 @@ async def get_job( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -438,7 +486,8 @@ async def list_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, filter]): + has_flattened_params = any([project_id, region, filter]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -465,13 +514,13 @@ async def list_jobs( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -525,7 +574,7 @@ async def update_job( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -586,7 +635,8 @@ async def cancel_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job_id]): + has_flattened_params = any([project_id, region, job_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -613,13 +663,13 @@ async def cancel_job( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -673,7 +723,8 @@ async def delete_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job_id]): + has_flattened_params = any([project_id, region, job_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -702,7 +753,7 @@ async def delete_job( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. 
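These hunks only reorder the retry predicate so ``exceptions.DeadlineExceeded`` is listed first; the set of retried errors and the backoff parameters are unchanged. A standalone sketch of the same policy built from the ``google.api_core`` primitives the patch uses:

    from google.api_core import exceptions
    from google.api_core import retry as retries

    job_retry = retries.Retry(
        initial=0.1,     # seconds before the first retry
        maximum=60.0,    # backoff ceiling, in seconds
        multiplier=1.3,  # exponential growth factor
        predicate=retries.if_exception_type(
            exceptions.DeadlineExceeded,
            exceptions.InternalServerError,
            exceptions.ServiceUnavailable,
        ),
    )

    @job_retry
    def flaky_call():
        """Stand-in for an RPC that may raise one of the errors above."""

Since ``if_exception_type`` simply ORs its arguments, the reordering is cosmetic and does not change which errors are retried.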
@@ -712,11 +763,11 @@ async def delete_job( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("JobControllerAsyncClient",) diff --git a/google/cloud/dataproc_v1/services/job_controller/client.py b/google/cloud/dataproc_v1/services/job_controller/client.py index 157d913e..d101e833 100644 --- a/google/cloud/dataproc_v1/services/job_controller/client.py +++ b/google/cloud/dataproc_v1/services/job_controller/client.py @@ -16,26 +16,28 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1.services.job_controller import pagers from google.cloud.dataproc_v1.types import jobs -from .transports.base import JobControllerTransport +from .transports.base import JobControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import JobControllerGrpcTransport from .transports.grpc_asyncio import JobControllerGrpcAsyncIOTransport @@ -128,12 +130,81 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> JobControllerTransport: + """Return the transport used by the client instance. + + Returns: + JobControllerTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, JobControllerTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobControllerTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the job controller client. @@ -146,48 +217,74 @@ def __init__( transport (Union[str, ~.JobControllerTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -211,11 +308,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def submit_job( @@ -795,11 +892,11 @@ def delete_job( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("JobControllerClient",) diff --git a/google/cloud/dataproc_v1/services/job_controller/transports/base.py b/google/cloud/dataproc_v1/services/job_controller/transports/base.py index d4200ffd..c8538dd1 100644 --- a/google/cloud/dataproc_v1/services/job_controller/transports/base.py +++ b/google/cloud/dataproc_v1/services/job_controller/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -32,11 +32,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class JobControllerTransport(abc.ABC): @@ -52,6 +52,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -69,6 +70,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,9 +102,9 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { self.submit_job: gapic_v1.method.wrap_method( @@ -110,7 +116,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.submit_job_as_operation: gapic_v1.method.wrap_method( self.submit_job_as_operation, @@ -121,7 +127,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.get_job: gapic_v1.method.wrap_method( self.get_job, @@ -130,13 +136,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.list_jobs: gapic_v1.method.wrap_method( self.list_jobs, @@ -145,13 +151,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.update_job: gapic_v1.method.wrap_method( self.update_job, @@ -162,7 +168,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.cancel_job: gapic_v1.method.wrap_method( self.cancel_job, @@ -171,13 +177,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.delete_job: gapic_v1.method.wrap_method( self.delete_job, @@ -188,7 +194,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1/services/job_controller/transports/grpc.py b/google/cloud/dataproc_v1/services/job_controller/transports/grpc.py index 6174c16b..504c8d4d 100644 --- a/google/cloud/dataproc_v1/services/job_controller/transports/grpc.py +++ b/google/cloud/dataproc_v1/services/job_controller/transports/grpc.py @@ -15,22 +15,23 @@ # limitations under the License. 
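Threading `client_info` into `_prep_wrapped_messages` means the caller's user-agent travels with every wrapped RPC. For reviewers unfamiliar with `gapic_v1.method.wrap_method`, here is a minimal sketch of what each entry in `_wrapped_methods` amounts to; `fake_rpc` is an invented stand-in for a generated stub method, and the retry numbers are the ones used in the hunk above:

    from google.api_core import exceptions, gapic_v1
    from google.api_core import retry as retries

    def fake_rpc(request=None, timeout=None, metadata=()):
        # Hypothetical stand-in for a gRPC stub method.
        return "ok"

    wrapped = gapic_v1.method.wrap_method(
        fake_rpc,
        default_retry=retries.Retry(
            initial=0.1,     # first backoff delay, in seconds
            maximum=60.0,    # cap on the delay
            multiplier=1.3,  # exponential backoff factor
            predicate=retries.if_exception_type(exceptions.ServiceUnavailable),
        ),
        default_timeout=900.0,
        client_info=gapic_v1.client_info.ClientInfo(),
    )

    print(wrapped())  # "ok", with the retry/timeout defaults and user-agent applied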
# +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1.types import jobs from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobControllerTransport +from .base import JobControllerTransport, DEFAULT_CLIENT_INFO class JobControllerGrpcTransport(JobControllerTransport): @@ -58,7 +59,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -77,16 +80,23 @@ def __init__( ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -94,6 +104,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -101,7 +113,13 @@ # If a channel was explicitly provided, set it.
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -132,6 +150,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -142,6 +178,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -152,7 +189,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -186,24 +223,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1/services/job_controller/transports/grpc_asyncio.py b/google/cloud/dataproc_v1/services/job_controller/transports/grpc_asyncio.py index 999141bf..6c60c089 100644 --- a/google/cloud/dataproc_v1/services/job_controller/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1/services/job_controller/transports/grpc_asyncio.py @@ -15,10 +15,13 @@ # limitations under the License. # +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -29,7 +32,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobControllerTransport +from .base import JobControllerTransport, DEFAULT_CLIENT_INFO from .grpc import JobControllerGrpcTransport @@ -100,7 +103,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. 
@@ -120,16 +125,23 @@ are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -137,6 +149,8 @@ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -144,13 +158,24 @@ # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -170,6 +195,24 @@ scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -178,6 +221,7 @@ credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -189,13 +233,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel.
""" - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/async_client.py b/google/cloud/dataproc_v1/services/workflow_template_service/async_client.py index dc5c0f22..44cb69b3 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/async_client.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/async_client.py @@ -28,14 +28,14 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1.services.workflow_template_service import pagers from google.cloud.dataproc_v1.types import workflow_templates from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import WorkflowTemplateServiceTransport +from .transports.base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import WorkflowTemplateServiceGrpcAsyncIOTransport from .client import WorkflowTemplateServiceClient @@ -53,10 +53,55 @@ class WorkflowTemplateServiceAsyncClient: workflow_template_path = staticmethod( WorkflowTemplateServiceClient.workflow_template_path ) + parse_workflow_template_path = staticmethod( + WorkflowTemplateServiceClient.parse_workflow_template_path + ) + + common_billing_account_path = staticmethod( + WorkflowTemplateServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(WorkflowTemplateServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + WorkflowTemplateServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod( + WorkflowTemplateServiceClient.common_project_path + ) + parse_common_project_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod( + WorkflowTemplateServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_location_path + ) from_service_account_file = WorkflowTemplateServiceClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> WorkflowTemplateServiceTransport: + """Return the transport used by the client instance. + + Returns: + WorkflowTemplateServiceTransport: The transport used by the client instance. 
+ """ + return self._client.transport + get_transport_class = functools.partial( type(WorkflowTemplateServiceClient).get_transport_class, type(WorkflowTemplateServiceClient), @@ -68,6 +113,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, WorkflowTemplateServiceTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the workflow template service client. @@ -83,16 +129,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -100,7 +149,10 @@ def __init__( """ self._client = WorkflowTemplateServiceClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def create_workflow_template( @@ -157,7 +209,8 @@ async def create_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, template]): + has_flattened_params = any([parent, template]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -184,7 +237,7 @@ async def create_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -249,7 +302,8 @@ async def get_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -272,13 +326,13 @@ async def get_workflow_template( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -389,7 +443,8 @@ async def instantiate_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, parameters]): + has_flattened_params = any([name, parameters]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -402,8 +457,9 @@ async def instantiate_workflow_template( if name is not None: request.name = name - if parameters is not None: - request.parameters = parameters + + if parameters: + request.parameters.update(parameters) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -416,7 +472,7 @@ async def instantiate_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -535,7 +591,8 @@ async def instantiate_inline_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, template]): + has_flattened_params = any([parent, template]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -562,7 +619,7 @@ async def instantiate_inline_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -626,7 +683,8 @@ async def update_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([template]): + has_flattened_params = any([template]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -651,7 +709,7 @@ async def update_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -719,7 +777,8 @@ async def list_workflow_templates( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
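The `common_*_path` helpers surfaced on the async client above (and defined on the sync client later in this patch) are simple format/parse pairs. An illustrative standalone rendition of the pattern, separate from the generated code:

    import re
    from typing import Dict

    def common_location_path(project: str, location: str) -> str:
        """Render a fully-qualified location resource name."""
        return "projects/{project}/locations/{location}".format(
            project=project, location=location
        )

    def parse_common_location_path(path: str) -> Dict[str, str]:
        """Parse a location path into its component segments."""
        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
        return m.groupdict() if m else {}

    path = common_location_path("my-project", "us-central1")
    print(path)                               # projects/my-project/locations/us-central1
    print(parse_common_location_path(path))   # {'project': 'my-project', 'location': 'us-central1'}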
- if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -742,13 +801,13 @@ async def list_workflow_templates( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -814,7 +873,8 @@ async def delete_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -839,7 +899,7 @@ async def delete_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -855,11 +915,11 @@ async def delete_workflow_template( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("WorkflowTemplateServiceAsyncClient",) diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/client.py b/google/cloud/dataproc_v1/services/workflow_template_service/client.py index 9b380a9a..73a5626b 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/client.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/client.py @@ -16,28 +16,30 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1.services.workflow_template_service import pagers from google.cloud.dataproc_v1.types import workflow_templates from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: 
ignore -from .transports.base import WorkflowTemplateServiceTransport +from .transports.base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import WorkflowTemplateServiceGrpcTransport from .transports.grpc_asyncio import WorkflowTemplateServiceGrpcAsyncIOTransport @@ -136,6 +138,15 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> WorkflowTemplateServiceTransport: + """Return the transport used by the client instance. + + Returns: + WorkflowTemplateServiceTransport: The transport used by the client instance. + """ + return self._transport + @staticmethod def workflow_template_path( project: str, region: str, workflow_template: str, @@ -154,12 +165,72 @@ def parse_workflow_template_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P<folder>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse an organization path into its component segments.""" + m = re.match(r"^organizations/(?P<organization>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P<project>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, WorkflowTemplateServiceTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, WorkflowTemplateServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the workflow template
service client. @@ -172,48 +243,74 @@ def __init__( transport (Union[str, ~.WorkflowTemplateServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + # Figure out which api endpoint to use. 
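The branch that follows gives an explicit `api_endpoint` precedence over both environment variables; callers set it through `client_options`. A hedged usage sketch (the regional endpoint string is only an example, and the client construction is commented out because it would require real credentials):

    from google.api_core.client_options import ClientOptions

    options = ClientOptions(api_endpoint="us-central1-dataproc.googleapis.com")
    # client = WorkflowTemplateServiceClient(client_options=options)
    # A plain dict such as {"api_endpoint": ...} is accepted too; the client
    # converts it via client_options_lib.from_dict() as shown in this hunk.
    print(options.api_endpoint)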
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -237,11 +334,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_workflow_template( @@ -539,8 +636,9 @@ def instantiate_workflow_template( if name is not None: request.name = name - if parameters is not None: - request.parameters = parameters + + if parameters: + request.parameters.update(parameters) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
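The switch above from assigning `request.parameters` to calling `.update()` reflects how protobuf map fields are mutated in place rather than replaced wholesale (note the `if parameters:` guard also skips empty mappings). A small proto-plus sketch of the same pattern, using an invented message rather than the generated request type:

    import proto

    class InstantiateRequest(proto.Message):
        # Invented stand-in for InstantiateWorkflowTemplateRequest.
        name = proto.Field(proto.STRING, number=1)
        parameters = proto.MapField(proto.STRING, proto.STRING, number=2)

    req = InstantiateRequest(name="projects/p/regions/r/workflowTemplates/t")
    # Map fields expose dict-like update() instead of plain assignment.
    req.parameters.update({"CLUSTER": "demo", "ZONE": "us-central1-a"})
    print(dict(req.parameters))  # {'CLUSTER': 'demo', 'ZONE': 'us-central1-a'}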
@@ -970,11 +1068,11 @@ def delete_workflow_template( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("WorkflowTemplateServiceClient",) diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/transports/base.py b/google/cloud/dataproc_v1/services/workflow_template_service/transports/base.py index a1bc72b0..967002f5 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/transports/base.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -32,11 +32,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class WorkflowTemplateServiceTransport(abc.ABC): @@ -52,6 +52,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -69,6 +70,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,9 +102,9 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { self.create_workflow_template: gapic_v1.method.wrap_method( @@ -110,7 +116,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.get_workflow_template: gapic_v1.method.wrap_method( self.get_workflow_template, @@ -119,13 +125,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.instantiate_workflow_template: gapic_v1.method.wrap_method( self.instantiate_workflow_template, @@ -136,7 +142,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.instantiate_inline_workflow_template: gapic_v1.method.wrap_method( self.instantiate_inline_workflow_template, @@ -147,7 +153,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.update_workflow_template: gapic_v1.method.wrap_method( self.update_workflow_template, @@ -158,7 +164,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.list_workflow_templates: gapic_v1.method.wrap_method( self.list_workflow_templates, @@ -167,13 +173,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.delete_workflow_template: gapic_v1.method.wrap_method( self.delete_workflow_template, @@ -184,7 +190,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc.py b/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc.py index 5a9e8b61..36e8cb35 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc.py @@ -15,22 +15,23 @@ # limitations under the License. 
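Promoting the module-private `_client_info` to a public `DEFAULT_CLIENT_INFO` lets wrapping libraries substitute their own metadata, which ultimately shapes the `x-goog-api-client` user-agent header. A sketch of building one (the extra `user_agent` value is made up, and the keyword assumes a reasonably recent google-api-core):

    import pkg_resources
    from google.api_core import gapic_v1

    try:
        version = pkg_resources.get_distribution("google-cloud-dataproc").version
    except pkg_resources.DistributionNotFound:
        version = None  # fall back to generic info, as the generated code does

    client_info = gapic_v1.client_info.ClientInfo(
        gapic_version=version,
        user_agent="my-partner-tool/1.2.3",  # hypothetical extra identifier
    )
    print(client_info.to_user_agent())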
# +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1.types import workflow_templates from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import WorkflowTemplateServiceTransport +from .base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO class WorkflowTemplateServiceGrpcTransport(WorkflowTemplateServiceTransport): @@ -59,7 +60,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -78,16 +81,23 @@ def __init__( ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -95,6 +105,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -102,7 +114,13 @@ # If a channel was explicitly provided, set it.
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -133,6 +151,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -143,6 +179,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -153,7 +190,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -187,24 +224,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc_asyncio.py b/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc_asyncio.py index 9e3f6355..c69eaf0f 100644 --- a/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc_asyncio.py @@ -15,10 +15,13 @@ # limitations under the License. 
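With the caching removed above, the channel is now created eagerly in the constructor, and a caller-supplied channel short-circuits everything, including `ssl_channel_credentials`. A hedged sketch of that construction path, assuming the post-patch library is installed and relying on the generated transports skipping credential lookup when an explicit channel is passed; no RPC is issued:

    import grpc
    from google.cloud.dataproc_v1.services.workflow_template_service.transports.grpc import (
        WorkflowTemplateServiceGrpcTransport,
    )

    # Hand the transport a ready-made channel (e.g. toward a local emulator);
    # endpoint and mTLS logic are bypassed, mirroring the `if channel:` branch.
    channel = grpc.insecure_channel("localhost:8888")
    transport = WorkflowTemplateServiceGrpcTransport(
        host="localhost:8888",
        channel=channel,
    )
    print(transport.grpc_channel is channel)  # True: the provided channel is used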
# +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -29,7 +32,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import WorkflowTemplateServiceTransport +from .base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .grpc import WorkflowTemplateServiceGrpcTransport @@ -101,7 +104,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -121,16 +126,23 @@ def __init__( are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -138,6 +150,8 @@ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -145,13 +159,24 @@ # If a channel was explicitly provided, set it.
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -171,6 +196,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -179,6 +222,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -190,13 +234,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/dataproc_v1/types/autoscaling_policies.py b/google/cloud/dataproc_v1/types/autoscaling_policies.py index 136dc3f7..edd3806a 100644 --- a/google/cloud/dataproc_v1/types/autoscaling_policies.py +++ b/google/cloud/dataproc_v1/types/autoscaling_policies.py @@ -250,7 +250,7 @@ class CreateAutoscalingPolicyRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - policy = proto.Field(proto.MESSAGE, number=2, message=AutoscalingPolicy,) + policy = proto.Field(proto.MESSAGE, number=2, message="AutoscalingPolicy",) class GetAutoscalingPolicyRequest(proto.Message): @@ -282,7 +282,7 @@ class UpdateAutoscalingPolicyRequest(proto.Message): Required. The updated autoscaling policy. 
""" - policy = proto.Field(proto.MESSAGE, number=1, message=AutoscalingPolicy,) + policy = proto.Field(proto.MESSAGE, number=1, message="AutoscalingPolicy",) class DeleteAutoscalingPolicyRequest(proto.Message): @@ -357,7 +357,9 @@ class ListAutoscalingPoliciesResponse(proto.Message): def raw_page(self): return self - policies = proto.RepeatedField(proto.MESSAGE, number=1, message=AutoscalingPolicy,) + policies = proto.RepeatedField( + proto.MESSAGE, number=1, message="AutoscalingPolicy", + ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1/types/clusters.py b/google/cloud/dataproc_v1/types/clusters.py index 50c0f5ee..9cb83872 100644 --- a/google/cloud/dataproc_v1/types/clusters.py +++ b/google/cloud/dataproc_v1/types/clusters.py @@ -903,7 +903,7 @@ class CreateClusterRequest(proto.Message): region = proto.Field(proto.STRING, number=3) - cluster = proto.Field(proto.MESSAGE, number=2, message=Cluster,) + cluster = proto.Field(proto.MESSAGE, number=2, message="Cluster",) request_id = proto.Field(proto.STRING, number=4) @@ -1018,7 +1018,7 @@ class UpdateClusterRequest(proto.Message): cluster_name = proto.Field(proto.STRING, number=2) - cluster = proto.Field(proto.MESSAGE, number=3, message=Cluster,) + cluster = proto.Field(proto.MESSAGE, number=3, message="Cluster",) graceful_decommission_timeout = proto.Field( proto.MESSAGE, number=6, message=duration.Duration, @@ -1161,7 +1161,7 @@ class ListClustersResponse(proto.Message): def raw_page(self): return self - clusters = proto.RepeatedField(proto.MESSAGE, number=1, message=Cluster,) + clusters = proto.RepeatedField(proto.MESSAGE, number=1, message="Cluster",) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1/types/jobs.py b/google/cloud/dataproc_v1/types/jobs.py index c4456e0b..84c0e3f6 100644 --- a/google/cloud/dataproc_v1/types/jobs.py +++ b/google/cloud/dataproc_v1/types/jobs.py @@ -145,7 +145,7 @@ class HadoopJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=7) - logging_config = proto.Field(proto.MESSAGE, number=8, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=8, message="LoggingConfig",) class SparkJob(proto.Message): @@ -203,7 +203,7 @@ class SparkJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=7) - logging_config = proto.Field(proto.MESSAGE, number=8, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=8, message="LoggingConfig",) class PySparkJob(proto.Message): @@ -263,7 +263,7 @@ class PySparkJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=7) - logging_config = proto.Field(proto.MESSAGE, number=8, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=8, message="LoggingConfig",) class QueryList(proto.Message): @@ -326,7 +326,7 @@ class HiveJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) continue_on_failure = proto.Field(proto.BOOL, number=3) @@ -368,7 +368,7 @@ class SparkSqlJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) script_variables = proto.MapField(proto.STRING, proto.STRING, 
number=3) @@ -377,7 +377,7 @@ class SparkSqlJob(proto.Message): jar_file_uris = proto.RepeatedField(proto.STRING, number=56) - logging_config = proto.Field(proto.MESSAGE, number=6, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=6, message="LoggingConfig",) class PigJob(proto.Message): @@ -415,7 +415,7 @@ class PigJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) continue_on_failure = proto.Field(proto.BOOL, number=3) @@ -426,7 +426,7 @@ class PigJob(proto.Message): jar_file_uris = proto.RepeatedField(proto.STRING, number=6) - logging_config = proto.Field(proto.MESSAGE, number=7, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=7, message="LoggingConfig",) class SparkRJob(proto.Message): @@ -475,7 +475,7 @@ class SparkRJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=5) - logging_config = proto.Field(proto.MESSAGE, number=6, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=6, message="LoggingConfig",) class PrestoJob(proto.Message): @@ -515,7 +515,7 @@ class PrestoJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) continue_on_failure = proto.Field(proto.BOOL, number=3) @@ -526,7 +526,7 @@ class PrestoJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=6) - logging_config = proto.Field(proto.MESSAGE, number=7, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=7, message="LoggingConfig",) class JobPlacement(proto.Message): @@ -738,44 +738,46 @@ class Job(proto.Message): will indicate if it was successful, failed, or cancelled. 
""" - reference = proto.Field(proto.MESSAGE, number=1, message=JobReference,) + reference = proto.Field(proto.MESSAGE, number=1, message="JobReference",) - placement = proto.Field(proto.MESSAGE, number=2, message=JobPlacement,) + placement = proto.Field(proto.MESSAGE, number=2, message="JobPlacement",) hadoop_job = proto.Field( - proto.MESSAGE, number=3, oneof="type_job", message=HadoopJob, + proto.MESSAGE, number=3, oneof="type_job", message="HadoopJob", ) spark_job = proto.Field( - proto.MESSAGE, number=4, oneof="type_job", message=SparkJob, + proto.MESSAGE, number=4, oneof="type_job", message="SparkJob", ) pyspark_job = proto.Field( - proto.MESSAGE, number=5, oneof="type_job", message=PySparkJob, + proto.MESSAGE, number=5, oneof="type_job", message="PySparkJob", ) - hive_job = proto.Field(proto.MESSAGE, number=6, oneof="type_job", message=HiveJob,) + hive_job = proto.Field( + proto.MESSAGE, number=6, oneof="type_job", message="HiveJob", + ) - pig_job = proto.Field(proto.MESSAGE, number=7, oneof="type_job", message=PigJob,) + pig_job = proto.Field(proto.MESSAGE, number=7, oneof="type_job", message="PigJob",) spark_r_job = proto.Field( - proto.MESSAGE, number=21, oneof="type_job", message=SparkRJob, + proto.MESSAGE, number=21, oneof="type_job", message="SparkRJob", ) spark_sql_job = proto.Field( - proto.MESSAGE, number=12, oneof="type_job", message=SparkSqlJob, + proto.MESSAGE, number=12, oneof="type_job", message="SparkSqlJob", ) presto_job = proto.Field( - proto.MESSAGE, number=23, oneof="type_job", message=PrestoJob, + proto.MESSAGE, number=23, oneof="type_job", message="PrestoJob", ) - status = proto.Field(proto.MESSAGE, number=8, message=JobStatus,) + status = proto.Field(proto.MESSAGE, number=8, message="JobStatus",) - status_history = proto.RepeatedField(proto.MESSAGE, number=13, message=JobStatus,) + status_history = proto.RepeatedField(proto.MESSAGE, number=13, message="JobStatus",) yarn_applications = proto.RepeatedField( - proto.MESSAGE, number=9, message=YarnApplication, + proto.MESSAGE, number=9, message="YarnApplication", ) driver_output_resource_uri = proto.Field(proto.STRING, number=17) @@ -843,7 +845,7 @@ class SubmitJobRequest(proto.Message): region = proto.Field(proto.STRING, number=3) - job = proto.Field(proto.MESSAGE, number=2, message=Job,) + job = proto.Field(proto.MESSAGE, number=2, message="Job",) request_id = proto.Field(proto.STRING, number=4) @@ -864,7 +866,7 @@ class JobMetadata(proto.Message): job_id = proto.Field(proto.STRING, number=1) - status = proto.Field(proto.MESSAGE, number=2, message=JobStatus,) + status = proto.Field(proto.MESSAGE, number=2, message="JobStatus",) operation_type = proto.Field(proto.STRING, number=3) @@ -988,7 +990,7 @@ class UpdateJobRequest(proto.Message): job_id = proto.Field(proto.STRING, number=3) - job = proto.Field(proto.MESSAGE, number=4, message=Job,) + job = proto.Field(proto.MESSAGE, number=4, message="Job",) update_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -1010,7 +1012,7 @@ class ListJobsResponse(proto.Message): def raw_page(self): return self - jobs = proto.RepeatedField(proto.MESSAGE, number=1, message=Job,) + jobs = proto.RepeatedField(proto.MESSAGE, number=1, message="Job",) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1/types/operations.py b/google/cloud/dataproc_v1/types/operations.py index e059814a..042e8c77 100644 --- a/google/cloud/dataproc_v1/types/operations.py +++ b/google/cloud/dataproc_v1/types/operations.py @@ -91,10 +91,10 @@ 
class ClusterOperationMetadata(proto.Message): cluster_uuid = proto.Field(proto.STRING, number=8) - status = proto.Field(proto.MESSAGE, number=9, message=ClusterOperationStatus,) + status = proto.Field(proto.MESSAGE, number=9, message="ClusterOperationStatus",) status_history = proto.RepeatedField( - proto.MESSAGE, number=10, message=ClusterOperationStatus, + proto.MESSAGE, number=10, message="ClusterOperationStatus", ) operation_type = proto.Field(proto.STRING, number=11) diff --git a/google/cloud/dataproc_v1/types/workflow_templates.py b/google/cloud/dataproc_v1/types/workflow_templates.py index 5d9182f1..50e8a469 100644 --- a/google/cloud/dataproc_v1/types/workflow_templates.py +++ b/google/cloud/dataproc_v1/types/workflow_templates.py @@ -608,7 +608,7 @@ class CreateWorkflowTemplateRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - template = proto.Field(proto.MESSAGE, number=2, message=WorkflowTemplate,) + template = proto.Field(proto.MESSAGE, number=2, message="WorkflowTemplate",) class GetWorkflowTemplateRequest(proto.Message): @@ -726,7 +726,7 @@ class InstantiateInlineWorkflowTemplateRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - template = proto.Field(proto.MESSAGE, number=2, message=WorkflowTemplate,) + template = proto.Field(proto.MESSAGE, number=2, message="WorkflowTemplate",) request_id = proto.Field(proto.STRING, number=3) @@ -742,7 +742,7 @@ class UpdateWorkflowTemplateRequest(proto.Message): version. """ - template = proto.Field(proto.MESSAGE, number=1, message=WorkflowTemplate,) + template = proto.Field(proto.MESSAGE, number=1, message="WorkflowTemplate",) class ListWorkflowTemplatesRequest(proto.Message): @@ -795,7 +795,9 @@ class ListWorkflowTemplatesResponse(proto.Message): def raw_page(self): return self - templates = proto.RepeatedField(proto.MESSAGE, number=1, message=WorkflowTemplate,) + templates = proto.RepeatedField( + proto.MESSAGE, number=1, message="WorkflowTemplate", + ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/async_client.py b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/async_client.py index 36274045..d3f3c9c9 100644 --- a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/async_client.py +++ b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/async_client.py @@ -31,7 +31,7 @@ from google.cloud.dataproc_v1beta2.services.autoscaling_policy_service import pagers from google.cloud.dataproc_v1beta2.types import autoscaling_policies -from .transports.base import AutoscalingPolicyServiceTransport +from .transports.base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import AutoscalingPolicyServiceGrpcAsyncIOTransport from .client import AutoscalingPolicyServiceClient @@ -49,10 +49,55 @@ class AutoscalingPolicyServiceAsyncClient: autoscaling_policy_path = staticmethod( AutoscalingPolicyServiceClient.autoscaling_policy_path ) + parse_autoscaling_policy_path = staticmethod( + AutoscalingPolicyServiceClient.parse_autoscaling_policy_path + ) + + common_billing_account_path = staticmethod( + AutoscalingPolicyServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + AutoscalingPolicyServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(AutoscalingPolicyServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + 
AutoscalingPolicyServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + AutoscalingPolicyServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + AutoscalingPolicyServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod( + AutoscalingPolicyServiceClient.common_project_path + ) + parse_common_project_path = staticmethod( + AutoscalingPolicyServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod( + AutoscalingPolicyServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + AutoscalingPolicyServiceClient.parse_common_location_path + ) from_service_account_file = AutoscalingPolicyServiceClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> AutoscalingPolicyServiceTransport: + """Return the transport used by the client instance. + + Returns: + AutoscalingPolicyServiceTransport: The transport used by the client instance. + """ + return self._client.transport + get_transport_class = functools.partial( type(AutoscalingPolicyServiceClient).get_transport_class, type(AutoscalingPolicyServiceClient), @@ -64,6 +109,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, AutoscalingPolicyServiceTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the autoscaling policy service client. @@ -79,16 +125,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -96,7 +145,10 @@ def __init__( """ self._client = AutoscalingPolicyServiceClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def create_autoscaling_policy( @@ -153,7 +205,8 @@ async def create_autoscaling_policy( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, policy]): + has_flattened_params = any([parent, policy]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -174,7 +227,7 @@ async def create_autoscaling_policy( rpc = gapic_v1.method_async.wrap_method( self._client._transport.create_autoscaling_policy, default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -229,7 +282,8 @@ async def update_autoscaling_policy( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([policy]): + has_flattened_params = any([policy]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -252,11 +306,11 @@ async def update_autoscaling_policy( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, + exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -320,7 +374,8 @@ async def get_autoscaling_policy( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -343,11 +398,11 @@ async def get_autoscaling_policy( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, + exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -412,7 +467,8 @@ async def list_autoscaling_policies( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
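
[Editor's note: the common resource path helpers and the `transport` property wired up above are pure, credential-free string utilities, so they can be exercised directly. A minimal usage sketch, assuming an installed build of google-cloud-dataproc that includes this patch; it is illustrative only and not part of the patch:]

    # Sketch: the new common resource path helpers. parse_* returns {} when
    # the supplied path does not match the expected pattern.
    from google.cloud.dataproc_v1beta2.services.autoscaling_policy_service import (
        AutoscalingPolicyServiceClient,
    )

    path = AutoscalingPolicyServiceClient.common_location_path(
        "my-project", "us-central1"
    )
    assert path == "projects/my-project/locations/us-central1"
    assert AutoscalingPolicyServiceClient.parse_common_location_path(path) == {
        "project": "my-project",
        "location": "us-central1",
    }
    assert AutoscalingPolicyServiceClient.parse_common_folder_path("not-a-path") == {}
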
@@ -435,11 +491,11 @@ async def list_autoscaling_policies( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, + exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -507,7 +563,8 @@ async def delete_autoscaling_policy( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -526,7 +583,7 @@ async def delete_autoscaling_policy( rpc = gapic_v1.method_async.wrap_method( self._client._transport.delete_autoscaling_policy, default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -542,11 +599,11 @@ async def delete_autoscaling_policy( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("AutoscalingPolicyServiceAsyncClient",) diff --git a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/client.py b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/client.py index 5bbc3b2d..bc80019f 100644 --- a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/client.py +++ b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/client.py @@ -16,24 +16,26 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore from google.cloud.dataproc_v1beta2.services.autoscaling_policy_service import pagers from google.cloud.dataproc_v1beta2.types import autoscaling_policies -from .transports.base import AutoscalingPolicyServiceTransport +from .transports.base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import AutoscalingPolicyServiceGrpcTransport from .transports.grpc_asyncio import AutoscalingPolicyServiceGrpcAsyncIOTransport @@ -132,6 +134,15 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> AutoscalingPolicyServiceTransport: 
+ """Return the transport used by the client instance. + + Returns: + AutoscalingPolicyServiceTransport: The transport used by the client instance. + """ + return self._transport + @staticmethod def autoscaling_policy_path( project: str, location: str, autoscaling_policy: str, @@ -150,12 +161,72 @@ def parse_autoscaling_policy_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, AutoscalingPolicyServiceTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, AutoscalingPolicyServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the autoscaling policy service client. @@ -168,48 +239,74 @@ def __init__( transport (Union[str, ~.AutoscalingPolicyServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. 
(1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. 
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -233,11 +330,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_autoscaling_policy( @@ -677,11 +774,11 @@ def delete_autoscaling_policy( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("AutoscalingPolicyServiceClient",) diff --git a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/base.py b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/base.py index f4da8ccb..bc039c5e 100644 --- a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/base.py +++ b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -30,11 +30,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class AutoscalingPolicyServiceTransport(abc.ABC): @@ -50,6 +50,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -67,6 +68,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. 
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -94,15 +100,15 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_autoscaling_policy: gapic_v1.method.wrap_method( self.create_autoscaling_policy, default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.update_autoscaling_policy: gapic_v1.method.wrap_method( self.update_autoscaling_policy, @@ -111,11 +117,11 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, + exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.get_autoscaling_policy: gapic_v1.method.wrap_method( self.get_autoscaling_policy, @@ -124,11 +130,11 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, + exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.list_autoscaling_policies: gapic_v1.method.wrap_method( self.list_autoscaling_policies, @@ -137,16 +143,16 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( - exceptions.ServiceUnavailable, exceptions.DeadlineExceeded, + exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.delete_autoscaling_policy: gapic_v1.method.wrap_method( self.delete_autoscaling_policy, default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc.py b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc.py index f1b5b894..ace75125 100644 --- a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc.py +++ b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc.py @@ -15,20 +15,21 @@ # limitations under the License. 
# +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1beta2.types import autoscaling_policies from google.protobuf import empty_pb2 as empty # type: ignore -from .base import AutoscalingPolicyServiceTransport +from .base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO class AutoscalingPolicyServiceGrpcTransport(AutoscalingPolicyServiceTransport): @@ -57,7 +58,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -76,16 +79,23 @@ def __init__( ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -93,6 +103,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -100,7 +112,13 @@ def __init__( # If a channel was explicitly provided, set it.
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -131,6 +149,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -141,6 +177,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -151,7 +188,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -185,24 +222,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc_asyncio.py b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc_asyncio.py index fa17bb26..f5a39178 100644 --- a/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc_asyncio.py @@ -15,9 +15,12 @@ # limitations under the License. 
# +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -27,7 +30,7 @@ from google.cloud.dataproc_v1beta2.types import autoscaling_policies from google.protobuf import empty_pb2 as empty # type: ignore -from .base import AutoscalingPolicyServiceTransport +from .base import AutoscalingPolicyServiceTransport, DEFAULT_CLIENT_INFO from .grpc import AutoscalingPolicyServiceGrpcTransport @@ -99,7 +102,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -119,16 +124,23 @@ def __init__( are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -136,6 +148,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -143,13 +157,24 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials.
if client_cert_source: @@ -169,6 +194,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -177,6 +220,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -188,13 +232,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/dataproc_v1beta2/services/cluster_controller/async_client.py b/google/cloud/dataproc_v1beta2/services/cluster_controller/async_client.py index 9a5af0d9..818d4287 100644 --- a/google/cloud/dataproc_v1beta2/services/cluster_controller/async_client.py +++ b/google/cloud/dataproc_v1beta2/services/cluster_controller/async_client.py @@ -28,15 +28,15 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1beta2.services.cluster_controller import pagers from google.cloud.dataproc_v1beta2.types import clusters from google.cloud.dataproc_v1beta2.types import operations from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import ClusterControllerTransport +from .transports.base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import ClusterControllerGrpcAsyncIOTransport from .client import ClusterControllerClient @@ -51,9 +51,47 @@ class ClusterControllerAsyncClient: DEFAULT_ENDPOINT = ClusterControllerClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = ClusterControllerClient.DEFAULT_MTLS_ENDPOINT + common_billing_account_path = staticmethod( + ClusterControllerClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ClusterControllerClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(ClusterControllerClient.common_folder_path) + parse_common_folder_path = staticmethod( + ClusterControllerClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + ClusterControllerClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + ClusterControllerClient.parse_common_organization_path + ) + + common_project_path = staticmethod(ClusterControllerClient.common_project_path) + parse_common_project_path = staticmethod( + ClusterControllerClient.parse_common_project_path + ) + + 
common_location_path = staticmethod(ClusterControllerClient.common_location_path) + parse_common_location_path = staticmethod( + ClusterControllerClient.parse_common_location_path + ) + from_service_account_file = ClusterControllerClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> ClusterControllerTransport: + """Return the transport used by the client instance. + + Returns: + ClusterControllerTransport: The transport used by the client instance. + """ + return self._client.transport + get_transport_class = functools.partial( type(ClusterControllerClient).get_transport_class, type(ClusterControllerClient) ) @@ -64,6 +102,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, ClusterControllerTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the cluster controller client. @@ -79,16 +118,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -96,7 +138,10 @@ def __init__( """ self._client = ClusterControllerClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def create_cluster( @@ -156,7 +201,8 @@ async def create_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster]): + has_flattened_params = any([project_id, region, cluster]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
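
[Editor's note: the `transport` property and the `client_info` constructor argument added above combine as follows. A sketch only, assuming the patched build is installed and Application Default Credentials are available; the user-agent value is a hypothetical placeholder:]

    # Sketch: pass a custom ClientInfo (sent in the x-goog-api-client
    # user-agent metadata) and read back the transport the client selected.
    from google.api_core.gapic_v1.client_info import ClientInfo
    from google.cloud.dataproc_v1beta2.services.cluster_controller import (
        ClusterControllerClient,
    )

    client = ClusterControllerClient(
        client_info=ClientInfo(user_agent="my-app/1.0"),  # hypothetical value
    )

    # With no explicit `transport` argument a gRPC transport is created; it is
    # now reachable read-only via the property instead of a private attribute.
    channel = client.transport.grpc_channel
    print(type(client.transport).__name__)  # ClusterControllerGrpcTransport
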
@@ -185,7 +231,7 @@ async def create_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -339,9 +385,10 @@ async def update_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any( + has_flattened_params = any( [project_id, region, cluster_name, cluster, update_mask] - ): + ) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -374,7 +421,7 @@ async def update_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -459,7 +506,8 @@ async def delete_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster_name]): + has_flattened_params = any([project_id, region, cluster_name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -488,7 +536,7 @@ async def delete_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -558,7 +606,8 @@ async def get_cluster( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster_name]): + has_flattened_params = any([project_id, region, cluster_name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -585,13 +634,13 @@ async def get_cluster( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -677,7 +726,8 @@ async def list_clusters( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, filter]): + has_flattened_params = any([project_id, region, filter]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -704,13 +754,13 @@ async def list_clusters( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -797,7 +847,8 @@ async def diagnose_cluster( # Create or coerce a protobuf request object. 
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, cluster_name]): + has_flattened_params = any([project_id, region, cluster_name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -826,7 +877,7 @@ async def diagnose_cluster( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -845,11 +896,11 @@ async def diagnose_cluster( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("ClusterControllerAsyncClient",) diff --git a/google/cloud/dataproc_v1beta2/services/cluster_controller/client.py b/google/cloud/dataproc_v1beta2/services/cluster_controller/client.py index 341e9622..f99564b7 100644 --- a/google/cloud/dataproc_v1beta2/services/cluster_controller/client.py +++ b/google/cloud/dataproc_v1beta2/services/cluster_controller/client.py @@ -16,29 +16,31 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1beta2.services.cluster_controller import pagers from google.cloud.dataproc_v1beta2.types import clusters from google.cloud.dataproc_v1beta2.types import operations from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import ClusterControllerTransport +from .transports.base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import ClusterControllerGrpcTransport from .transports.grpc_asyncio import ClusterControllerGrpcAsyncIOTransport @@ -137,12 +139,81 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> ClusterControllerTransport: + """Return the transport used by the client instance. + + Returns: + ClusterControllerTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, ClusterControllerTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ClusterControllerTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the cluster controller client. @@ -155,48 +226,74 @@ def __init__( transport (Union[str, ~.ClusterControllerTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -220,11 +317,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_cluster( @@ -941,11 +1038,11 @@ def diagnose_cluster( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("ClusterControllerClient",) diff --git a/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/base.py b/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/base.py index 864494b6..5e0d3298 100644 --- a/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/base.py +++ b/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -31,11 +31,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class ClusterControllerTransport(abc.ABC): @@ -51,6 +51,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -68,6 +69,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -95,9 +101,9 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { self.create_cluster: gapic_v1.method.wrap_method( @@ -109,7 +115,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.update_cluster: gapic_v1.method.wrap_method( self.update_cluster, @@ -120,7 +126,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.delete_cluster: gapic_v1.method.wrap_method( self.delete_cluster, @@ -131,7 +137,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.get_cluster: gapic_v1.method.wrap_method( self.get_cluster, @@ -140,13 +146,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.list_clusters: gapic_v1.method.wrap_method( self.list_clusters, @@ -155,13 +161,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), self.diagnose_cluster: gapic_v1.method.wrap_method( self.diagnose_cluster, @@ -172,7 +178,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=300.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc.py b/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc.py index abb5622f..c8b16361 100644 --- a/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc.py +++ b/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc.py @@ -15,21 +15,22 @@ # limitations under the License. # +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1beta2.types import clusters from google.longrunning import operations_pb2 as operations # type: ignore -from .base import ClusterControllerTransport +from .base import ClusterControllerTransport, DEFAULT_CLIENT_INFO class ClusterControllerGrpcTransport(ClusterControllerTransport): @@ -58,7 +59,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. 
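The ``_prep_wrapped_messages`` change above threads the caller-supplied ``client_info`` into each wrapped method instead of the module-level default. A minimal sketch of the same wrapping, assuming only ``google-api-core`` (the ``call_rpc`` stand-in is hypothetical):

    from google.api_core import exceptions, gapic_v1
    from google.api_core import retry as retries

    def call_rpc(request, timeout=None, retry=None, metadata=None):
        """Hypothetical stand-in for a gRPC stub method."""

    wrapped = gapic_v1.method.wrap_method(
        call_rpc,
        default_retry=retries.Retry(
            initial=0.1,      # first backoff, in seconds
            maximum=60.0,     # backoff ceiling
            multiplier=1.3,   # exponential growth factor
            # Idempotent reads also retry DeadlineExceeded, per the diff above.
            predicate=retries.if_exception_type(
                exceptions.DeadlineExceeded,
                exceptions.InternalServerError,
                exceptions.ServiceUnavailable,
            ),
        ),
        default_timeout=300.0,
        client_info=gapic_v1.client_info.ClientInfo(),  # supplies the user-agent
    )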
@@ -77,16 +80,23 @@ def __init__( ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -94,6 +104,8 @@ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -101,7 +113,13 @@ # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -132,6 +150,24 @@ scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -142,6 +178,7 @@ credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -152,7 +189,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object.
Args: @@ -186,24 +223,13 @@ create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc_asyncio.py b/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc_asyncio.py index 4d778267..0f17284b 100644 --- a/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc_asyncio.py @@ -15,10 +15,13 @@ # limitations under the License. # +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -28,7 +31,7 @@ from google.cloud.dataproc_v1beta2.types import clusters from google.longrunning import operations_pb2 as operations # type: ignore -from .base import ClusterControllerTransport +from .base import ClusterControllerTransport, DEFAULT_CLIENT_INFO from .grpc import ClusterControllerGrpcTransport @@ -100,7 +103,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -120,16 +125,23 @@ are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -137,6 +149,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -144,13 +158,24 @@ def __init__( # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -170,6 +195,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -178,6 +221,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -189,13 +233,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. 
return self._grpc_channel diff --git a/google/cloud/dataproc_v1beta2/services/job_controller/async_client.py b/google/cloud/dataproc_v1beta2/services/job_controller/async_client.py index b83e2612..57234d85 100644 --- a/google/cloud/dataproc_v1beta2/services/job_controller/async_client.py +++ b/google/cloud/dataproc_v1beta2/services/job_controller/async_client.py @@ -28,12 +28,12 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1beta2.services.job_controller import pagers from google.cloud.dataproc_v1beta2.types import jobs -from .transports.base import JobControllerTransport +from .transports.base import JobControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import JobControllerGrpcAsyncIOTransport from .client import JobControllerClient @@ -46,9 +46,47 @@ class JobControllerAsyncClient: DEFAULT_ENDPOINT = JobControllerClient.DEFAULT_ENDPOINT DEFAULT_MTLS_ENDPOINT = JobControllerClient.DEFAULT_MTLS_ENDPOINT + common_billing_account_path = staticmethod( + JobControllerClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + JobControllerClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(JobControllerClient.common_folder_path) + parse_common_folder_path = staticmethod( + JobControllerClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + JobControllerClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + JobControllerClient.parse_common_organization_path + ) + + common_project_path = staticmethod(JobControllerClient.common_project_path) + parse_common_project_path = staticmethod( + JobControllerClient.parse_common_project_path + ) + + common_location_path = staticmethod(JobControllerClient.common_location_path) + parse_common_location_path = staticmethod( + JobControllerClient.parse_common_location_path + ) + from_service_account_file = JobControllerClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> JobControllerTransport: + """Return the transport used by the client instance. + + Returns: + JobControllerTransport: The transport used by the client instance. + """ + return self._client.transport + get_transport_class = functools.partial( type(JobControllerClient).get_transport_class, type(JobControllerClient) ) @@ -59,6 +97,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, JobControllerTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the job controller client. @@ -74,16 +113,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS + default endpoint provided by the client.
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -91,7 +133,10 @@ def __init__( """ self._client = JobControllerClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def submit_job( @@ -142,7 +187,8 @@ async def submit_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job]): + has_flattened_params = any([project_id, region, job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -171,7 +217,7 @@ async def submit_job( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -232,7 +278,8 @@ async def submit_job_as_operation( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job]): + has_flattened_params = any([project_id, region, job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -261,7 +308,7 @@ async def submit_job_as_operation( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -328,7 +375,8 @@ async def get_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job_id]): + has_flattened_params = any([project_id, region, job_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
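The docstring above describes the construction-time logic now shared by all of these clients: GOOGLE_API_USE_CLIENT_CERTIFICATE decides whether a client certificate is sent, and GOOGLE_API_USE_MTLS_ENDPOINT decides which endpoint is dialed. A sketch of the interaction, assuming application default credentials are configured:

    import os
    from google.cloud.dataproc_v1beta2 import JobControllerClient

    # Never send a client certificate and always use the regular endpoint,
    # even if a default client certificate is present on the machine.
    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "false"
    os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "never"

    client = JobControllerClient()
    # With "auto" (the default), the client would instead switch to
    # JobControllerClient.DEFAULT_MTLS_ENDPOINT whenever a certificate is found.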
@@ -355,13 +403,13 @@ async def get_job( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -438,7 +486,8 @@ async def list_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, filter]): + has_flattened_params = any([project_id, region, filter]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -465,13 +514,13 @@ async def list_jobs( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -525,7 +574,7 @@ async def update_job( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -586,7 +635,8 @@ async def cancel_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job_id]): + has_flattened_params = any([project_id, region, job_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -613,13 +663,13 @@ async def cancel_job( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. @@ -673,7 +723,8 @@ async def delete_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([project_id, region, job_id]): + has_flattened_params = any([project_id, region, job_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -702,7 +753,7 @@ async def delete_job( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Send the request. 
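The ``has_flattened_params`` refactor above keeps the existing contract: pass either a request object or the flattened fields, never both. For example, on the synchronous client (the project, region, and job IDs are hypothetical):

    from google.cloud.dataproc_v1beta2 import JobControllerClient
    from google.cloud.dataproc_v1beta2.types import GetJobRequest

    client = JobControllerClient()

    # Option 1: a fully-populated request object.
    job = client.get_job(
        request=GetJobRequest(
            project_id="my-project", region="us-central1", job_id="job-1"
        )
    )

    # Option 2: the flattened convenience fields.
    job = client.get_job(project_id="my-project", region="us-central1", job_id="job-1")

    # Mixing the two raises ValueError before any RPC is made.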
@@ -712,11 +763,11 @@ async def delete_job( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("JobControllerAsyncClient",) diff --git a/google/cloud/dataproc_v1beta2/services/job_controller/client.py b/google/cloud/dataproc_v1beta2/services/job_controller/client.py index e34798cd..0989f37a 100644 --- a/google/cloud/dataproc_v1beta2/services/job_controller/client.py +++ b/google/cloud/dataproc_v1beta2/services/job_controller/client.py @@ -16,26 +16,28 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1beta2.services.job_controller import pagers from google.cloud.dataproc_v1beta2.types import jobs -from .transports.base import JobControllerTransport +from .transports.base import JobControllerTransport, DEFAULT_CLIENT_INFO from .transports.grpc import JobControllerGrpcTransport from .transports.grpc_asyncio import JobControllerGrpcAsyncIOTransport @@ -128,12 +130,81 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> JobControllerTransport: + """Return the transport used by the client instance. + + Returns: + JobControllerTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, JobControllerTransport] = None, - client_options: ClientOptions = None, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobControllerTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the job controller client. @@ -146,48 +217,74 @@ def __init__( transport (Union[str, ~.JobControllerTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -211,11 +308,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def submit_job( @@ -795,11 +892,11 @@ def delete_job( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("JobControllerClient",) diff --git a/google/cloud/dataproc_v1beta2/services/job_controller/transports/base.py b/google/cloud/dataproc_v1beta2/services/job_controller/transports/base.py index 99a86d34..deea5d1c 100644 --- a/google/cloud/dataproc_v1beta2/services/job_controller/transports/base.py +++ b/google/cloud/dataproc_v1beta2/services/job_controller/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -32,11 +32,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class JobControllerTransport(abc.ABC): @@ -52,6 +52,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -69,6 +70,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,9 +102,9 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. 
self._wrapped_methods = { self.submit_job: gapic_v1.method.wrap_method( @@ -110,7 +116,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.submit_job_as_operation: gapic_v1.method.wrap_method( self.submit_job_as_operation, @@ -121,7 +127,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.get_job: gapic_v1.method.wrap_method( self.get_job, @@ -130,13 +136,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.list_jobs: gapic_v1.method.wrap_method( self.list_jobs, @@ -145,13 +151,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.update_job: gapic_v1.method.wrap_method( self.update_job, @@ -162,7 +168,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.cancel_job: gapic_v1.method.wrap_method( self.cancel_job, @@ -171,13 +177,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), self.delete_job: gapic_v1.method.wrap_method( self.delete_job, @@ -188,7 +194,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=900.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc.py b/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc.py index 19aa92cc..800181a1 100644 --- a/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc.py +++ b/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc.py @@ -15,22 +15,23 @@ # limitations under the License. 
# +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1beta2.types import jobs from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobControllerTransport +from .base import JobControllerTransport, DEFAULT_CLIENT_INFO class JobControllerGrpcTransport(JobControllerTransport): @@ -58,7 +59,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -77,16 +80,23 @@ ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -94,6 +104,8 @@ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -101,7 +113,13 @@ # If a channel was explicitly provided, set it.
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -132,6 +150,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -142,6 +178,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -152,7 +189,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -186,24 +223,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc_asyncio.py b/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc_asyncio.py index fcc056ab..5d23c945 100644 --- a/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc_asyncio.py @@ -15,10 +15,13 @@ # limitations under the License. 
# +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -29,7 +32,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobControllerTransport +from .base import JobControllerTransport, DEFAULT_CLIENT_INFO from .grpc import JobControllerGrpcTransport @@ -100,7 +103,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -120,16 +125,23 @@ are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -137,6 +149,8 @@ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -144,13 +158,24 @@ # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials.
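# For illustration (a sketch, with hypothetical file paths): a
# ``client_cert_source`` callback is just a zero-argument callable
# returning PEM-encoded (cert, key) bytes,
#
#     def client_cert_source():
#         with open("client_cert.pem", "rb") as f:
#             cert = f.read()
#         with open("client_key.pem", "rb") as f:
#             key = f.read()
#         return cert, key
#
# and the deprecated branch below turns that pair into channel credentials
# via grpc.ssl_channel_credentials(certificate_chain=cert, private_key=key).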
if client_cert_source: @@ -170,6 +195,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -178,6 +221,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -189,13 +233,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/dataproc_v1beta2/services/workflow_template_service/async_client.py b/google/cloud/dataproc_v1beta2/services/workflow_template_service/async_client.py index 94e84da0..e6dbfb51 100644 --- a/google/cloud/dataproc_v1beta2/services/workflow_template_service/async_client.py +++ b/google/cloud/dataproc_v1beta2/services/workflow_template_service/async_client.py @@ -28,14 +28,14 @@ from google.auth import credentials # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1beta2.services.workflow_template_service import pagers from google.cloud.dataproc_v1beta2.types import workflow_templates from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import WorkflowTemplateServiceTransport +from .transports.base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc_asyncio import WorkflowTemplateServiceGrpcAsyncIOTransport from .client import WorkflowTemplateServiceClient @@ -53,10 +53,55 @@ class WorkflowTemplateServiceAsyncClient: workflow_template_path = staticmethod( WorkflowTemplateServiceClient.workflow_template_path ) + parse_workflow_template_path = staticmethod( + WorkflowTemplateServiceClient.parse_workflow_template_path + ) + + common_billing_account_path = staticmethod( + WorkflowTemplateServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(WorkflowTemplateServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + WorkflowTemplateServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod( + 
WorkflowTemplateServiceClient.common_project_path + ) + parse_common_project_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod( + WorkflowTemplateServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + WorkflowTemplateServiceClient.parse_common_location_path + ) from_service_account_file = WorkflowTemplateServiceClient.from_service_account_file from_service_account_json = from_service_account_file + @property + def transport(self) -> WorkflowTemplateServiceTransport: + """Return the transport used by the client instance. + + Returns: + WorkflowTemplateServiceTransport: The transport used by the client instance. + """ + return self._client.transport + get_transport_class = functools.partial( type(WorkflowTemplateServiceClient).get_transport_class, type(WorkflowTemplateServiceClient), @@ -68,6 +113,7 @@ def __init__( credentials: credentials.Credentials = None, transport: Union[str, WorkflowTemplateServiceTransport] = "grpc_asyncio", client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the workflow template service client. @@ -83,16 +129,19 @@ def __init__( client_options (ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -100,7 +149,10 @@ def __init__( """ self._client = WorkflowTemplateServiceClient( - credentials=credentials, transport=transport, client_options=client_options, + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, ) async def create_workflow_template( @@ -157,7 +209,8 @@ async def create_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- if request is not None and any([parent, template]): + has_flattened_params = any([parent, template]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -184,7 +237,7 @@ async def create_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -249,7 +302,8 @@ async def get_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -272,13 +326,13 @@ async def get_workflow_template( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -389,7 +443,8 @@ async def instantiate_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, parameters]): + has_flattened_params = any([name, parameters]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -402,8 +457,9 @@ async def instantiate_workflow_template( if name is not None: request.name = name - if parameters is not None: - request.parameters = parameters + + if parameters: + request.parameters.update(parameters) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -416,7 +472,7 @@ async def instantiate_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -535,7 +591,8 @@ async def instantiate_inline_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, template]): + has_flattened_params = any([parent, template]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -562,7 +619,7 @@ async def instantiate_inline_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -626,7 +683,8 @@ async def update_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
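# --- Usage sketch (not part of the patch): the request-vs-flattened-arguments
# guard shown above. `client` is assumed to be an already-built
# WorkflowTemplateServiceAsyncClient; the resource name is a placeholder.
from google.cloud.dataproc_v1beta2.types import workflow_templates

async def demo(client):
    name = "projects/my-project/regions/us-central1/workflowTemplates/my-template"
    request = workflow_templates.GetWorkflowTemplateRequest(name=name)
    await client.get_workflow_template(request=request)  # request object: ok
    await client.get_workflow_template(name=name)        # flattened field: ok
    # Passing both at once trips the guard and raises ValueError:
    # await client.get_workflow_template(request=request, name=name)
# --- end sketch ---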
- if request is not None and any([template]): + has_flattened_params = any([template]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -651,7 +709,7 @@ async def update_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -719,7 +777,8 @@ async def list_workflow_templates( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." @@ -742,13 +801,13 @@ async def list_workflow_templates( maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -814,7 +873,8 @@ async def delete_workflow_template( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
@@ -839,7 +899,7 @@ async def delete_workflow_template( predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=DEFAULT_CLIENT_INFO, ) # Certain fields should be provided within the metadata header; @@ -855,11 +915,11 @@ async def delete_workflow_template( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("WorkflowTemplateServiceAsyncClient",) diff --git a/google/cloud/dataproc_v1beta2/services/workflow_template_service/client.py b/google/cloud/dataproc_v1beta2/services/workflow_template_service/client.py index 5c529342..f383a694 100644 --- a/google/cloud/dataproc_v1beta2/services/workflow_template_service/client.py +++ b/google/cloud/dataproc_v1beta2/services/workflow_template_service/client.py @@ -16,28 +16,30 @@ # from collections import OrderedDict +from distutils import util import os import re -from typing import Callable, Dict, Sequence, Tuple, Type, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation -from google.api_core import operation_async +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.dataproc_v1beta2.services.workflow_template_service import pagers from google.cloud.dataproc_v1beta2.types import workflow_templates from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import WorkflowTemplateServiceTransport +from .transports.base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import WorkflowTemplateServiceGrpcTransport from .transports.grpc_asyncio import WorkflowTemplateServiceGrpcAsyncIOTransport @@ -136,6 +138,15 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> WorkflowTemplateServiceTransport: + """Return the transport used by the client instance. + + Returns: + WorkflowTemplateServiceTransport: The transport used by the client instance. 
+        """
+        return self._transport
+
     @staticmethod
     def workflow_template_path(
         project: str, region: str, workflow_template: str,
@@ -154,12 +165,72 @@ def parse_workflow_template_path(path: str) -> Dict[str, str]:
         )
         return m.groupdict() if m else {}
 
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
     def __init__(
         self,
         *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, WorkflowTemplateServiceTransport] = None,
-        client_options: ClientOptions = None,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, WorkflowTemplateServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the workflow template service client.
@@ -172,48 +243,74 @@ def __init__(
             transport (Union[str, ~.WorkflowTemplateServiceTransport]): The
                 transport to use. If set to None, a transport is
                 chosen automatically.
-            client_options (ClientOptions): Custom options for the client. It
-                won't take effect if a ``transport`` instance is provided.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
             (1) The ``api_endpoint`` property can be used to override the
-            default endpoint provided by the client. GOOGLE_API_USE_MTLS
+            default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT environment variable can also be used to override the endpoint: "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint, this is the default value for - the environment variable) and "auto" (auto switch to the default - mTLS endpoint if client SSL credentials is present). However, - the ``api_endpoint`` property takes precedence if provided. - (2) The ``client_cert_source`` property is used to provide client - SSL credentials for mutual TLS transport. If not provided, the - default SSL credentials will be used if present. + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) if client_options is None: - client_options = ClientOptions.ClientOptions() + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None - if client_options.api_endpoint is None: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS", "never") + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_mtls_env == "never": - client_options.api_endpoint = self.DEFAULT_ENDPOINT + api_endpoint = self.DEFAULT_ENDPOINT elif use_mtls_env == "always": - client_options.api_endpoint = self.DEFAULT_MTLS_ENDPOINT + api_endpoint = self.DEFAULT_MTLS_ENDPOINT elif use_mtls_env == "auto": - has_client_cert_source = ( - client_options.client_cert_source is not None - or mtls.has_default_client_cert_source() - ) - client_options.api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT - if has_client_cert_source - else self.DEFAULT_ENDPOINT + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT ) else: raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS value. Accepted values: never, auto, always" + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. 
Accepted values: never, auto, always" ) # Save or instantiate the transport. @@ -237,11 +334,11 @@ def __init__( self._transport = Transport( credentials=credentials, credentials_file=client_options.credentials_file, - host=client_options.api_endpoint, + host=api_endpoint, scopes=client_options.scopes, - api_mtls_endpoint=client_options.api_endpoint, - client_cert_source=client_options.client_cert_source, + ssl_channel_credentials=ssl_credentials, quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_workflow_template( @@ -539,8 +636,9 @@ def instantiate_workflow_template( if name is not None: request.name = name - if parameters is not None: - request.parameters = parameters + + if parameters: + request.parameters.update(parameters) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. @@ -970,11 +1068,11 @@ def delete_workflow_template( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("WorkflowTemplateServiceClient",) diff --git a/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/base.py b/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/base.py index 3dc6c0ad..2495d556 100644 --- a/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/base.py +++ b/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/base.py @@ -19,7 +19,7 @@ import typing import pkg_resources -from google import auth +from google import auth # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore @@ -32,11 +32,11 @@ try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution("google-cloud-dataproc",).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() class WorkflowTemplateServiceTransport(abc.ABC): @@ -52,6 +52,7 @@ def __init__( credentials_file: typing.Optional[str] = None, scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, **kwargs, ) -> None: """Instantiate the transport. @@ -69,6 +70,11 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -96,9 +102,9 @@ def __init__( self._credentials = credentials # Lifted into its own function so it can be stubbed out during tests. 
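# --- Usage sketch (not part of the patch): wiring a custom user-agent through
# the new client_info argument. The agent string and credentials below are
# placeholders; when client_info is omitted, DEFAULT_CLIENT_INFO is used.
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import AnonymousCredentials
from google.cloud.dataproc_v1beta2.services.workflow_template_service import (
    WorkflowTemplateServiceClient,
)

client = WorkflowTemplateServiceClient(
    credentials=AnonymousCredentials(),
    client_info=ClientInfo(user_agent="my-sample-app/1.0"),
)
# The client hands client_info to the transport, which forwards it into
# _prep_wrapped_messages(), so every wrapped RPC carries the custom user-agent.
# --- end sketch ---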
- self._prep_wrapped_messages() + self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self): + def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { self.create_workflow_template: gapic_v1.method.wrap_method( @@ -110,7 +116,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.get_workflow_template: gapic_v1.method.wrap_method( self.get_workflow_template, @@ -119,13 +125,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.instantiate_workflow_template: gapic_v1.method.wrap_method( self.instantiate_workflow_template, @@ -136,7 +142,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.instantiate_inline_workflow_template: gapic_v1.method.wrap_method( self.instantiate_inline_workflow_template, @@ -147,7 +153,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.update_workflow_template: gapic_v1.method.wrap_method( self.update_workflow_template, @@ -158,7 +164,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.list_workflow_templates: gapic_v1.method.wrap_method( self.list_workflow_templates, @@ -167,13 +173,13 @@ def _prep_wrapped_messages(self): maximum=60.0, multiplier=1.3, predicate=retries.if_exception_type( + exceptions.DeadlineExceeded, exceptions.InternalServerError, exceptions.ServiceUnavailable, - exceptions.DeadlineExceeded, ), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), self.delete_workflow_template: gapic_v1.method.wrap_method( self.delete_workflow_template, @@ -184,7 +190,7 @@ def _prep_wrapped_messages(self): predicate=retries.if_exception_type(exceptions.ServiceUnavailable,), ), default_timeout=600.0, - client_info=_client_info, + client_info=client_info, ), } diff --git a/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc.py b/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc.py index 6ab10372..b78dc810 100644 --- a/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc.py +++ b/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc.py @@ -15,22 +15,23 @@ # limitations under the License. 
# +import warnings from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - import grpc # type: ignore from google.cloud.dataproc_v1beta2.types import workflow_templates from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import WorkflowTemplateServiceTransport +from .base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO class WorkflowTemplateServiceGrpcTransport(WorkflowTemplateServiceTransport): @@ -59,7 +60,9 @@ def __init__( channel: grpc.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - quota_project_id: Optional[str] = None + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -78,16 +81,23 @@ def __init__( ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport @@ -95,6 +105,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -102,7 +114,13 @@ def __init__( # If a channel was explicitly provided, set it. 
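# --- Usage sketch (not part of the patch): the new ssl_channel_credentials
# argument, which replaces the deprecated api_mtls_endpoint/client_cert_source
# pair. The PEM file paths are placeholders; when no credentials are given,
# the transport falls back to auth.default() as shown in this hunk.
import grpc
from google.cloud.dataproc_v1beta2.services.workflow_template_service.transports.grpc import (
    WorkflowTemplateServiceGrpcTransport,
)

with open("client-cert.pem", "rb") as f:
    cert = f.read()
with open("client-key.pem", "rb") as f:
    key = f.read()

transport = WorkflowTemplateServiceGrpcTransport(
    host="dataproc.mtls.googleapis.com",
    ssl_channel_credentials=grpc.ssl_channel_credentials(
        certificate_chain=cert, private_key=key
    ),
)
# --- end sketch ---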
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -133,6 +151,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) self._stubs = {} # type: Dict[str, Callable] @@ -143,6 +179,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) @classmethod @@ -153,7 +190,7 @@ def create_channel( credentials_file: str = None, scopes: Optional[Sequence[str]] = None, quota_project_id: Optional[str] = None, - **kwargs + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -187,24 +224,13 @@ def create_channel( credentials_file=credentials_file, scopes=scopes, quota_project_id=quota_project_id, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc_asyncio.py b/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc_asyncio.py index d085b7b1..11921398 100644 --- a/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc_asyncio.py +++ b/google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc_asyncio.py @@ -15,10 +15,13 @@ # limitations under the License. 
# +import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple +from google.api_core import gapic_v1 # type: ignore from google.api_core import grpc_helpers_async # type: ignore from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore @@ -29,7 +32,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import WorkflowTemplateServiceTransport +from .base import WorkflowTemplateServiceTransport, DEFAULT_CLIENT_INFO from .grpc import WorkflowTemplateServiceGrpcTransport @@ -101,7 +104,9 @@ def __init__( channel: aio.Channel = None, api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -121,16 +126,23 @@ def __init__( are passed to :func:`google.auth.default`. channel (Optional[aio.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): The mutual TLS endpoint. If - provided, it overrides the ``host`` argument and tries to create + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create a mutual TLS channel with client SSL credentials from ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): A - callback to provide client SSL certificate bytes and private key - bytes, both in PEM format. It is ignored if ``api_mtls_endpoint`` - is None. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. Raises: google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport @@ -138,6 +150,8 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -145,13 +159,24 @@ def __init__( # If a channel was explicitly provided, set it. 
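# --- Usage sketch (not part of the patch): handing a pre-built channel to the
# asyncio transport. Per the sanity check above, credentials must then be
# omitted; create_channel() resolves application default credentials itself,
# so this assumes ADC is available in the environment.
from google.cloud.dataproc_v1beta2.services.workflow_template_service.transports.grpc_asyncio import (
    WorkflowTemplateServiceGrpcAsyncIOTransport,
)

channel = WorkflowTemplateServiceGrpcAsyncIOTransport.create_channel(
    "dataproc.googleapis.com:443",
)
transport = WorkflowTemplateServiceGrpcAsyncIOTransport(channel=channel)
# --- end sketch ---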
self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + host = ( api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443" ) + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + # Create SSL credentials with client_cert_source or application # default SSL credentials. if client_cert_source: @@ -171,6 +196,24 @@ def __init__( scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) # Run the base constructor. super().__init__( @@ -179,6 +222,7 @@ def __init__( credentials_file=credentials_file, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + client_info=client_info, ) self._stubs = {} @@ -190,13 +234,6 @@ def grpc_channel(self) -> aio.Channel: This property caches on the instance; repeated calls return the same channel. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/dataproc_v1beta2/types/autoscaling_policies.py b/google/cloud/dataproc_v1beta2/types/autoscaling_policies.py index 453f4954..ebc355c6 100644 --- a/google/cloud/dataproc_v1beta2/types/autoscaling_policies.py +++ b/google/cloud/dataproc_v1beta2/types/autoscaling_policies.py @@ -245,7 +245,7 @@ class CreateAutoscalingPolicyRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - policy = proto.Field(proto.MESSAGE, number=2, message=AutoscalingPolicy,) + policy = proto.Field(proto.MESSAGE, number=2, message="AutoscalingPolicy",) class GetAutoscalingPolicyRequest(proto.Message): @@ -277,7 +277,7 @@ class UpdateAutoscalingPolicyRequest(proto.Message): Required. The updated autoscaling policy. 
""" - policy = proto.Field(proto.MESSAGE, number=1, message=AutoscalingPolicy,) + policy = proto.Field(proto.MESSAGE, number=1, message="AutoscalingPolicy",) class DeleteAutoscalingPolicyRequest(proto.Message): @@ -352,7 +352,9 @@ class ListAutoscalingPoliciesResponse(proto.Message): def raw_page(self): return self - policies = proto.RepeatedField(proto.MESSAGE, number=1, message=AutoscalingPolicy,) + policies = proto.RepeatedField( + proto.MESSAGE, number=1, message="AutoscalingPolicy", + ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1beta2/types/clusters.py b/google/cloud/dataproc_v1beta2/types/clusters.py index d2a3a450..bb747d0a 100644 --- a/google/cloud/dataproc_v1beta2/types/clusters.py +++ b/google/cloud/dataproc_v1beta2/types/clusters.py @@ -916,7 +916,7 @@ class CreateClusterRequest(proto.Message): region = proto.Field(proto.STRING, number=3) - cluster = proto.Field(proto.MESSAGE, number=2, message=Cluster,) + cluster = proto.Field(proto.MESSAGE, number=2, message="Cluster",) request_id = proto.Field(proto.STRING, number=4) @@ -1039,7 +1039,7 @@ class UpdateClusterRequest(proto.Message): cluster_name = proto.Field(proto.STRING, number=2) - cluster = proto.Field(proto.MESSAGE, number=3, message=Cluster,) + cluster = proto.Field(proto.MESSAGE, number=3, message="Cluster",) graceful_decommission_timeout = proto.Field( proto.MESSAGE, number=6, message=duration.Duration, @@ -1182,7 +1182,7 @@ class ListClustersResponse(proto.Message): def raw_page(self): return self - clusters = proto.RepeatedField(proto.MESSAGE, number=1, message=Cluster,) + clusters = proto.RepeatedField(proto.MESSAGE, number=1, message="Cluster",) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1beta2/types/jobs.py b/google/cloud/dataproc_v1beta2/types/jobs.py index b94e9a3c..02a3e4ec 100644 --- a/google/cloud/dataproc_v1beta2/types/jobs.py +++ b/google/cloud/dataproc_v1beta2/types/jobs.py @@ -145,7 +145,7 @@ class HadoopJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=7) - logging_config = proto.Field(proto.MESSAGE, number=8, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=8, message="LoggingConfig",) class SparkJob(proto.Message): @@ -209,7 +209,7 @@ class SparkJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=7) - logging_config = proto.Field(proto.MESSAGE, number=8, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=8, message="LoggingConfig",) class PySparkJob(proto.Message): @@ -269,7 +269,7 @@ class PySparkJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=7) - logging_config = proto.Field(proto.MESSAGE, number=8, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=8, message="LoggingConfig",) class QueryList(proto.Message): @@ -332,7 +332,7 @@ class HiveJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) continue_on_failure = proto.Field(proto.BOOL, number=3) @@ -374,7 +374,7 @@ class SparkSqlJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) script_variables = 
proto.MapField(proto.STRING, proto.STRING, number=3) @@ -383,7 +383,7 @@ class SparkSqlJob(proto.Message): jar_file_uris = proto.RepeatedField(proto.STRING, number=56) - logging_config = proto.Field(proto.MESSAGE, number=6, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=6, message="LoggingConfig",) class PigJob(proto.Message): @@ -421,7 +421,7 @@ class PigJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) continue_on_failure = proto.Field(proto.BOOL, number=3) @@ -432,7 +432,7 @@ class PigJob(proto.Message): jar_file_uris = proto.RepeatedField(proto.STRING, number=6) - logging_config = proto.Field(proto.MESSAGE, number=7, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=7, message="LoggingConfig",) class SparkRJob(proto.Message): @@ -482,7 +482,7 @@ class SparkRJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=5) - logging_config = proto.Field(proto.MESSAGE, number=6, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=6, message="LoggingConfig",) class PrestoJob(proto.Message): @@ -522,7 +522,7 @@ class PrestoJob(proto.Message): query_file_uri = proto.Field(proto.STRING, number=1, oneof="queries") query_list = proto.Field( - proto.MESSAGE, number=2, oneof="queries", message=QueryList, + proto.MESSAGE, number=2, oneof="queries", message="QueryList", ) continue_on_failure = proto.Field(proto.BOOL, number=3) @@ -533,7 +533,7 @@ class PrestoJob(proto.Message): properties = proto.MapField(proto.STRING, proto.STRING, number=6) - logging_config = proto.Field(proto.MESSAGE, number=7, message=LoggingConfig,) + logging_config = proto.Field(proto.MESSAGE, number=7, message="LoggingConfig",) class JobPlacement(proto.Message): @@ -747,44 +747,46 @@ class Job(proto.Message): will indicate if it was successful, failed, or cancelled. 
""" - reference = proto.Field(proto.MESSAGE, number=1, message=JobReference,) + reference = proto.Field(proto.MESSAGE, number=1, message="JobReference",) - placement = proto.Field(proto.MESSAGE, number=2, message=JobPlacement,) + placement = proto.Field(proto.MESSAGE, number=2, message="JobPlacement",) hadoop_job = proto.Field( - proto.MESSAGE, number=3, oneof="type_job", message=HadoopJob, + proto.MESSAGE, number=3, oneof="type_job", message="HadoopJob", ) spark_job = proto.Field( - proto.MESSAGE, number=4, oneof="type_job", message=SparkJob, + proto.MESSAGE, number=4, oneof="type_job", message="SparkJob", ) pyspark_job = proto.Field( - proto.MESSAGE, number=5, oneof="type_job", message=PySparkJob, + proto.MESSAGE, number=5, oneof="type_job", message="PySparkJob", ) - hive_job = proto.Field(proto.MESSAGE, number=6, oneof="type_job", message=HiveJob,) + hive_job = proto.Field( + proto.MESSAGE, number=6, oneof="type_job", message="HiveJob", + ) - pig_job = proto.Field(proto.MESSAGE, number=7, oneof="type_job", message=PigJob,) + pig_job = proto.Field(proto.MESSAGE, number=7, oneof="type_job", message="PigJob",) spark_r_job = proto.Field( - proto.MESSAGE, number=21, oneof="type_job", message=SparkRJob, + proto.MESSAGE, number=21, oneof="type_job", message="SparkRJob", ) spark_sql_job = proto.Field( - proto.MESSAGE, number=12, oneof="type_job", message=SparkSqlJob, + proto.MESSAGE, number=12, oneof="type_job", message="SparkSqlJob", ) presto_job = proto.Field( - proto.MESSAGE, number=23, oneof="type_job", message=PrestoJob, + proto.MESSAGE, number=23, oneof="type_job", message="PrestoJob", ) - status = proto.Field(proto.MESSAGE, number=8, message=JobStatus,) + status = proto.Field(proto.MESSAGE, number=8, message="JobStatus",) - status_history = proto.RepeatedField(proto.MESSAGE, number=13, message=JobStatus,) + status_history = proto.RepeatedField(proto.MESSAGE, number=13, message="JobStatus",) yarn_applications = proto.RepeatedField( - proto.MESSAGE, number=9, message=YarnApplication, + proto.MESSAGE, number=9, message="YarnApplication", ) submitted_by = proto.Field(proto.STRING, number=10) @@ -838,7 +840,7 @@ class JobMetadata(proto.Message): job_id = proto.Field(proto.STRING, number=1) - status = proto.Field(proto.MESSAGE, number=2, message=JobStatus,) + status = proto.Field(proto.MESSAGE, number=2, message="JobStatus",) operation_type = proto.Field(proto.STRING, number=3) @@ -878,7 +880,7 @@ class SubmitJobRequest(proto.Message): region = proto.Field(proto.STRING, number=3) - job = proto.Field(proto.MESSAGE, number=2, message=Job,) + job = proto.Field(proto.MESSAGE, number=2, message="Job",) request_id = proto.Field(proto.STRING, number=4) @@ -1000,7 +1002,7 @@ class UpdateJobRequest(proto.Message): job_id = proto.Field(proto.STRING, number=3) - job = proto.Field(proto.MESSAGE, number=4, message=Job,) + job = proto.Field(proto.MESSAGE, number=4, message="Job",) update_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -1022,7 +1024,7 @@ class ListJobsResponse(proto.Message): def raw_page(self): return self - jobs = proto.RepeatedField(proto.MESSAGE, number=1, message=Job,) + jobs = proto.RepeatedField(proto.MESSAGE, number=1, message="Job",) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/dataproc_v1beta2/types/operations.py b/google/cloud/dataproc_v1beta2/types/operations.py index b43dc854..13baf0df 100644 --- a/google/cloud/dataproc_v1beta2/types/operations.py +++ b/google/cloud/dataproc_v1beta2/types/operations.py @@ -91,10 
+91,10 @@ class ClusterOperationMetadata(proto.Message): cluster_uuid = proto.Field(proto.STRING, number=8) - status = proto.Field(proto.MESSAGE, number=9, message=ClusterOperationStatus,) + status = proto.Field(proto.MESSAGE, number=9, message="ClusterOperationStatus",) status_history = proto.RepeatedField( - proto.MESSAGE, number=10, message=ClusterOperationStatus, + proto.MESSAGE, number=10, message="ClusterOperationStatus", ) operation_type = proto.Field(proto.STRING, number=11) diff --git a/google/cloud/dataproc_v1beta2/types/workflow_templates.py b/google/cloud/dataproc_v1beta2/types/workflow_templates.py index 50319c8b..31b80e6d 100644 --- a/google/cloud/dataproc_v1beta2/types/workflow_templates.py +++ b/google/cloud/dataproc_v1beta2/types/workflow_templates.py @@ -616,7 +616,7 @@ class CreateWorkflowTemplateRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - template = proto.Field(proto.MESSAGE, number=2, message=WorkflowTemplate,) + template = proto.Field(proto.MESSAGE, number=2, message="WorkflowTemplate",) class GetWorkflowTemplateRequest(proto.Message): @@ -740,7 +740,7 @@ class InstantiateInlineWorkflowTemplateRequest(proto.Message): parent = proto.Field(proto.STRING, number=1) - template = proto.Field(proto.MESSAGE, number=2, message=WorkflowTemplate,) + template = proto.Field(proto.MESSAGE, number=2, message="WorkflowTemplate",) instance_id = proto.Field(proto.STRING, number=3) @@ -758,7 +758,7 @@ class UpdateWorkflowTemplateRequest(proto.Message): version. """ - template = proto.Field(proto.MESSAGE, number=1, message=WorkflowTemplate,) + template = proto.Field(proto.MESSAGE, number=1, message="WorkflowTemplate",) class ListWorkflowTemplatesRequest(proto.Message): @@ -811,7 +811,9 @@ class ListWorkflowTemplatesResponse(proto.Message): def raw_page(self): return self - templates = proto.RepeatedField(proto.MESSAGE, number=1, message=WorkflowTemplate,) + templates = proto.RepeatedField( + proto.MESSAGE, number=1, message="WorkflowTemplate", + ) next_page_token = proto.Field(proto.STRING, number=2) diff --git a/noxfile.py b/noxfile.py index 6374c11f..f3ed2bed 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,7 +28,7 @@ DEFAULT_PYTHON_VERSION = "3.8" SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] -UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"] +UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] @nox.session(python=DEFAULT_PYTHON_VERSION) @@ -72,7 +72,9 @@ def default(session): # Install all test dependencies, then install this package in-place. session.install("asyncmock", "pytest-asyncio") - session.install("mock", "pytest", "pytest-cov") + session.install( + "mock", "pytest", "pytest-cov", + ) session.install("-e", ".") # Run py.test against the unit tests. @@ -173,7 +175,9 @@ def docfx(session): """Build the docfx yaml files for this library.""" session.install("-e", ".") - session.install("sphinx<3.0.0", "alabaster", "recommonmark", "sphinx-docfx-yaml") + # sphinx-docfx-yaml supports up to sphinx version 1.5.5. + # https://github.com/docascode/sphinx-docfx-yaml/issues/97 + session.install("sphinx==1.5.5", "alabaster", "recommonmark", "sphinx-docfx-yaml") shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) session.run( diff --git a/samples/snippets/noxfile.py b/samples/snippets/noxfile.py index ba55d7ce..b90eef00 100644 --- a/samples/snippets/noxfile.py +++ b/samples/snippets/noxfile.py @@ -39,6 +39,10 @@ # You can opt out from the test for specific Python versions. 
'ignored_versions': ["2.7"], + # Old samples are opted out of enforcing Python type hints + # All new samples should feature them + 'enforce_type_hints': False, + # An envvar key for determining the project id to use. Change it # to 'BUILD_SPECIFIC_GCLOUD_PROJECT' if you want to opt in using a # build specific Cloud project. You can also use your own string @@ -132,7 +136,10 @@ def _determine_local_import_names(start_dir): @nox.session def lint(session): - session.install("flake8", "flake8-import-order") + if not TEST_CONFIG['enforce_type_hints']: + session.install("flake8", "flake8-import-order") + else: + session.install("flake8", "flake8-import-order", "flake8-annotations") local_names = _determine_local_import_names(".") args = FLAKE8_COMMON_ARGS + [ @@ -141,8 +148,18 @@ def lint(session): "." ] session.run("flake8", *args) +# +# Black +# +@nox.session +def blacken(session): + session.install("black") + python_files = [path for path in os.listdir(".") if path.endswith(".py")] + + session.run("black", *python_files) + # # Sample Tests # @@ -201,6 +218,11 @@ def _get_repo_root(): break if Path(p / ".git").exists(): return str(p) + # .git is not available in repos cloned via Cloud Build + # setup.py is always in the library's root, so use that instead + # https://github.com/googleapis/synthtool/issues/792 + if Path(p / "setup.py").exists(): + return str(p) p = p.parent raise Exception("Unable to detect repository root.") diff --git a/scripts/decrypt-secrets.sh b/scripts/decrypt-secrets.sh index ff599eb2..21f6d2a2 100755 --- a/scripts/decrypt-secrets.sh +++ b/scripts/decrypt-secrets.sh @@ -20,14 +20,27 @@ ROOT=$( dirname "$DIR" ) # Work from the project root. cd $ROOT +# Prevent it from overriding files. +# We recommend that sample authors use their own service account files and cloud project. +# In that case, they are supposed to prepare these files by themselves. +if [[ -f "testing/test-env.sh" ]] || \ + [[ -f "testing/service-account.json" ]] || \ + [[ -f "testing/client-secrets.json" ]]; then + echo "One or more target files exist, aborting." + exit 1 +fi + # Use SECRET_MANAGER_PROJECT if set, fallback to cloud-devrel-kokoro-resources. PROJECT_ID="${SECRET_MANAGER_PROJECT:-cloud-devrel-kokoro-resources}" gcloud secrets versions access latest --secret="python-docs-samples-test-env" \ + --project="${PROJECT_ID}" \ > testing/test-env.sh gcloud secrets versions access latest \ --secret="python-docs-samples-service-account" \ + --project="${PROJECT_ID}" \ > testing/service-account.json gcloud secrets versions access latest \ --secret="python-docs-samples-client-secrets" \ - > testing/client-secrets.json \ No newline at end of file + --project="${PROJECT_ID}" \ + > testing/client-secrets.json diff --git a/scripts/fixup_dataproc_v1_keywords.py b/scripts/fixup_dataproc_v1_keywords.py index 9824550a..92228e53 100644 --- a/scripts/fixup_dataproc_v1_keywords.py +++ b/scripts/fixup_dataproc_v1_keywords.py @@ -1,3 +1,4 @@ +#! /usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/scripts/fixup_dataproc_v1beta2_keywords.py b/scripts/fixup_dataproc_v1beta2_keywords.py index ecadef2b..11f2e445 100644 --- a/scripts/fixup_dataproc_v1beta2_keywords.py +++ b/scripts/fixup_dataproc_v1beta2_keywords.py @@ -1,3 +1,4 @@ +#! 
/usr/bin/env python3 # -*- coding: utf-8 -*- # Copyright 2020 Google LLC diff --git a/synth.metadata b/synth.metadata index d3160be7..b8d02d9f 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,21 +4,29 @@ "git": { "name": ".", "remote": "https://github.com/googleapis/python-dataproc.git", - "sha": "31af47932ebc6d0b5df0ba1a0f7a208ab55578ae" + "sha": "98d81ab998a752f36de3ac1e4c1e0b9e291616ae" + } + }, + { + "git": { + "name": "googleapis", + "remote": "https://github.com/googleapis/googleapis.git", + "sha": "be0bdf86cd31aa7c1a7b30a9a2e9f2fd53ee3d91", + "internalRef": "342353190" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "94421c47802f56a44c320257b2b4c190dc7d6b68" + "sha": "e89175cf074dccc4babb4eca66ae913696e47a71" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "94421c47802f56a44c320257b2b4c190dc7d6b68" + "sha": "e89175cf074dccc4babb4eca66ae913696e47a71" } } ], @@ -41,5 +49,188 @@ "generator": "bazel" } } + ], + "generatedFiles": [ + ".flake8", + ".github/CONTRIBUTING.md", + ".github/ISSUE_TEMPLATE/bug_report.md", + ".github/ISSUE_TEMPLATE/feature_request.md", + ".github/ISSUE_TEMPLATE/support_request.md", + ".github/PULL_REQUEST_TEMPLATE.md", + ".github/release-please.yml", + ".github/snippet-bot.yml", + ".gitignore", + ".kokoro/build.sh", + ".kokoro/continuous/common.cfg", + ".kokoro/continuous/continuous.cfg", + ".kokoro/docker/docs/Dockerfile", + ".kokoro/docker/docs/fetch_gpg_keys.sh", + ".kokoro/docs/common.cfg", + ".kokoro/docs/docs-presubmit.cfg", + ".kokoro/docs/docs.cfg", + ".kokoro/populate-secrets.sh", + ".kokoro/presubmit/common.cfg", + ".kokoro/presubmit/presubmit.cfg", + ".kokoro/publish-docs.sh", + ".kokoro/release.sh", + ".kokoro/release/common.cfg", + ".kokoro/release/release.cfg", + ".kokoro/samples/lint/common.cfg", + ".kokoro/samples/lint/continuous.cfg", + ".kokoro/samples/lint/periodic.cfg", + ".kokoro/samples/lint/presubmit.cfg", + ".kokoro/samples/python3.6/common.cfg", + ".kokoro/samples/python3.6/continuous.cfg", + ".kokoro/samples/python3.6/periodic.cfg", + ".kokoro/samples/python3.6/presubmit.cfg", + ".kokoro/samples/python3.7/common.cfg", + ".kokoro/samples/python3.7/continuous.cfg", + ".kokoro/samples/python3.7/periodic.cfg", + ".kokoro/samples/python3.7/presubmit.cfg", + ".kokoro/samples/python3.8/common.cfg", + ".kokoro/samples/python3.8/continuous.cfg", + ".kokoro/samples/python3.8/periodic.cfg", + ".kokoro/samples/python3.8/presubmit.cfg", + ".kokoro/test-samples.sh", + ".kokoro/trampoline.sh", + ".kokoro/trampoline_v2.sh", + ".trampolinerc", + "CODE_OF_CONDUCT.md", + "CONTRIBUTING.rst", + "LICENSE", + "MANIFEST.in", + "docs/_static/custom.css", + "docs/_templates/layout.html", + "docs/conf.py", + "docs/dataproc_v1/services.rst", + "docs/dataproc_v1/types.rst", + "docs/dataproc_v1beta2/services.rst", + "docs/dataproc_v1beta2/types.rst", + "docs/multiprocessing.rst", + "google/cloud/dataproc/__init__.py", + "google/cloud/dataproc/py.typed", + "google/cloud/dataproc_v1/__init__.py", + "google/cloud/dataproc_v1/proto/autoscaling_policies.proto", + "google/cloud/dataproc_v1/proto/clusters.proto", + "google/cloud/dataproc_v1/proto/jobs.proto", + "google/cloud/dataproc_v1/proto/operations.proto", + "google/cloud/dataproc_v1/proto/shared.proto", + "google/cloud/dataproc_v1/proto/workflow_templates.proto", + "google/cloud/dataproc_v1/py.typed", + "google/cloud/dataproc_v1/services/__init__.py", + 
"google/cloud/dataproc_v1/services/autoscaling_policy_service/__init__.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/async_client.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/client.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/pagers.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/__init__.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/base.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc.py", + "google/cloud/dataproc_v1/services/autoscaling_policy_service/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1/services/cluster_controller/__init__.py", + "google/cloud/dataproc_v1/services/cluster_controller/async_client.py", + "google/cloud/dataproc_v1/services/cluster_controller/client.py", + "google/cloud/dataproc_v1/services/cluster_controller/pagers.py", + "google/cloud/dataproc_v1/services/cluster_controller/transports/__init__.py", + "google/cloud/dataproc_v1/services/cluster_controller/transports/base.py", + "google/cloud/dataproc_v1/services/cluster_controller/transports/grpc.py", + "google/cloud/dataproc_v1/services/cluster_controller/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1/services/job_controller/__init__.py", + "google/cloud/dataproc_v1/services/job_controller/async_client.py", + "google/cloud/dataproc_v1/services/job_controller/client.py", + "google/cloud/dataproc_v1/services/job_controller/pagers.py", + "google/cloud/dataproc_v1/services/job_controller/transports/__init__.py", + "google/cloud/dataproc_v1/services/job_controller/transports/base.py", + "google/cloud/dataproc_v1/services/job_controller/transports/grpc.py", + "google/cloud/dataproc_v1/services/job_controller/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1/services/workflow_template_service/__init__.py", + "google/cloud/dataproc_v1/services/workflow_template_service/async_client.py", + "google/cloud/dataproc_v1/services/workflow_template_service/client.py", + "google/cloud/dataproc_v1/services/workflow_template_service/pagers.py", + "google/cloud/dataproc_v1/services/workflow_template_service/transports/__init__.py", + "google/cloud/dataproc_v1/services/workflow_template_service/transports/base.py", + "google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc.py", + "google/cloud/dataproc_v1/services/workflow_template_service/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1/types/__init__.py", + "google/cloud/dataproc_v1/types/autoscaling_policies.py", + "google/cloud/dataproc_v1/types/clusters.py", + "google/cloud/dataproc_v1/types/jobs.py", + "google/cloud/dataproc_v1/types/operations.py", + "google/cloud/dataproc_v1/types/shared.py", + "google/cloud/dataproc_v1/types/workflow_templates.py", + "google/cloud/dataproc_v1beta2/__init__.py", + "google/cloud/dataproc_v1beta2/proto/autoscaling_policies.proto", + "google/cloud/dataproc_v1beta2/proto/clusters.proto", + "google/cloud/dataproc_v1beta2/proto/jobs.proto", + "google/cloud/dataproc_v1beta2/proto/operations.proto", + "google/cloud/dataproc_v1beta2/proto/shared.proto", + "google/cloud/dataproc_v1beta2/proto/workflow_templates.proto", + "google/cloud/dataproc_v1beta2/py.typed", + "google/cloud/dataproc_v1beta2/services/__init__.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/__init__.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/async_client.py", + 
"google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/client.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/pagers.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/__init__.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/base.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc.py", + "google/cloud/dataproc_v1beta2/services/autoscaling_policy_service/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/__init__.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/async_client.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/client.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/pagers.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/transports/__init__.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/transports/base.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc.py", + "google/cloud/dataproc_v1beta2/services/cluster_controller/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1beta2/services/job_controller/__init__.py", + "google/cloud/dataproc_v1beta2/services/job_controller/async_client.py", + "google/cloud/dataproc_v1beta2/services/job_controller/client.py", + "google/cloud/dataproc_v1beta2/services/job_controller/pagers.py", + "google/cloud/dataproc_v1beta2/services/job_controller/transports/__init__.py", + "google/cloud/dataproc_v1beta2/services/job_controller/transports/base.py", + "google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc.py", + "google/cloud/dataproc_v1beta2/services/job_controller/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/__init__.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/async_client.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/client.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/pagers.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/__init__.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/base.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc.py", + "google/cloud/dataproc_v1beta2/services/workflow_template_service/transports/grpc_asyncio.py", + "google/cloud/dataproc_v1beta2/types/__init__.py", + "google/cloud/dataproc_v1beta2/types/autoscaling_policies.py", + "google/cloud/dataproc_v1beta2/types/clusters.py", + "google/cloud/dataproc_v1beta2/types/jobs.py", + "google/cloud/dataproc_v1beta2/types/operations.py", + "google/cloud/dataproc_v1beta2/types/shared.py", + "google/cloud/dataproc_v1beta2/types/workflow_templates.py", + "mypy.ini", + "noxfile.py", + "renovate.json", + "samples/AUTHORING_GUIDE.md", + "samples/CONTRIBUTING.md", + "samples/snippets/noxfile.py", + "scripts/decrypt-secrets.sh", + "scripts/fixup_dataproc_v1_keywords.py", + "scripts/fixup_dataproc_v1beta2_keywords.py", + "scripts/readme-gen/readme_gen.py", + "scripts/readme-gen/templates/README.tmpl.rst", + "scripts/readme-gen/templates/auth.tmpl.rst", + "scripts/readme-gen/templates/auth_api_key.tmpl.rst", + "scripts/readme-gen/templates/install_deps.tmpl.rst", + "scripts/readme-gen/templates/install_portaudio.tmpl.rst", + "setup.cfg", + "testing/.gitignore", + "tests/unit/gapic/dataproc_v1/__init__.py", + 
"tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py", + "tests/unit/gapic/dataproc_v1/test_cluster_controller.py", + "tests/unit/gapic/dataproc_v1/test_job_controller.py", + "tests/unit/gapic/dataproc_v1/test_workflow_template_service.py", + "tests/unit/gapic/dataproc_v1beta2/__init__.py", + "tests/unit/gapic/dataproc_v1beta2/test_autoscaling_policy_service.py", + "tests/unit/gapic/dataproc_v1beta2/test_cluster_controller.py", + "tests/unit/gapic/dataproc_v1beta2/test_job_controller.py", + "tests/unit/gapic/dataproc_v1beta2/test_workflow_template_service.py" ] } \ No newline at end of file diff --git a/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py b/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py index bf660db4..dab06a53 100644 --- a/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py +++ b/tests/unit/gapic/dataproc_v1/test_autoscaling_policy_service.py @@ -101,12 +101,12 @@ def test_autoscaling_policy_service_client_from_service_account_file(client_clas ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_autoscaling_policy_service_client_get_transport_class(): @@ -170,14 +170,14 @@ def test_autoscaling_policy_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -186,14 +186,14 @@ def test_autoscaling_policy_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -202,90 +202,185 @@ def test_autoscaling_policy_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceGrpcTransport, + "grpc", + "true", + ), + ( + AutoscalingPolicyServiceAsyncClient, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceGrpcTransport, + "grpc", + "false", + ), + ( + AutoscalingPolicyServiceAsyncClient, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + AutoscalingPolicyServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AutoscalingPolicyServiceClient), +) +@mock.patch.object( + AutoscalingPolicyServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AutoscalingPolicyServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_autoscaling_policy_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and default_client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): - with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds ): patched.return_value = None - client = client_class() + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, + host=expected_host, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=expected_ssl_channel_creds, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", but client_cert_source and default_client_cert_source are None. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None ): - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has - # unsupported value. 
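[Editor's illustration] Taken together, the replaced blocks above encode a new environment-variable contract. A short sketch of that contract as the new tests pin it down; the variable names are the real ones, and the listed values are the only ones the tests accept:

    import os

    # GOOGLE_API_USE_MTLS_ENDPOINT supersedes GOOGLE_API_USE_MTLS and takes
    # "never", "always", or "auto"; any other value raises
    # MutualTLSChannelError at client construction, per the test above.
    os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "auto"

    # GOOGLE_API_USE_CLIENT_CERTIFICATE is new and takes "true" or "false";
    # any other value raises ValueError. With "true" plus an available client
    # certificate, "auto" switches the client to the mTLS endpoint.
    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"
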
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id="octopus", - ) + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) @pytest.mark.parametrize( @@ -316,9 +411,9 @@ def test_autoscaling_policy_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -350,9 +445,9 @@ def test_autoscaling_policy_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -369,9 +464,9 @@ def test_autoscaling_policy_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -389,7 +484,7 @@ def test_create_autoscaling_policy( # Mock the 
actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy( @@ -411,6 +506,7 @@ def test_create_autoscaling_policy( assert args[0] == autoscaling_policies.CreateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) assert response.id == "id_value" @@ -423,18 +519,21 @@ def test_create_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_create_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_create_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.CreateAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.CreateAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -447,7 +546,7 @@ async def test_create_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.CreateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. assert isinstance(response, autoscaling_policies.AutoscalingPolicy) @@ -457,6 +556,11 @@ async def test_create_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert response.name == "name_value" +@pytest.mark.asyncio +async def test_create_autoscaling_policy_async_from_dict(): + await test_create_autoscaling_policy_async(request_type=dict) + + def test_create_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -469,7 +573,7 @@ def test_create_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -498,7 +602,7 @@ async def test_create_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.AutoscalingPolicy() @@ -523,7 +627,7 @@ def test_create_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -568,7 +672,7 @@ async def test_create_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -623,7 +727,7 @@ def test_update_autoscaling_policy( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy( @@ -645,6 +749,7 @@ def test_update_autoscaling_policy( assert args[0] == autoscaling_policies.UpdateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) assert response.id == "id_value" @@ -657,18 +762,21 @@ def test_update_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_update_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_update_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.UpdateAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.UpdateAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -681,7 +789,7 @@ async def test_update_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.UpdateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. assert isinstance(response, autoscaling_policies.AutoscalingPolicy) @@ -691,6 +799,11 @@ async def test_update_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert response.name == "name_value" +@pytest.mark.asyncio +async def test_update_autoscaling_policy_async_from_dict(): + await test_update_autoscaling_policy_async(request_type=dict) + + def test_update_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -703,7 +816,7 @@ def test_update_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -732,7 +845,7 @@ async def test_update_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.AutoscalingPolicy() @@ -757,7 +870,7 @@ def test_update_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -798,7 +911,7 @@ async def test_update_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -849,7 +962,7 @@ def test_get_autoscaling_policy( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy( @@ -871,6 +984,7 @@ def test_get_autoscaling_policy( assert args[0] == autoscaling_policies.GetAutoscalingPolicyRequest() # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) assert response.id == "id_value" @@ -883,18 +997,21 @@ def test_get_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_get_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_get_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.GetAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.GetAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -907,7 +1024,7 @@ async def test_get_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.GetAutoscalingPolicyRequest() # Establish that the response is the type that we expect. 
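[Editor's illustration] The repetitive change in these hunks is that tests now patch `client.transport.<rpc>` directly, for sync and async clients alike, instead of reaching through `client._transport` or `client._client._transport`. A sketch of the same pattern outside the test harness; the policy name is hypothetical, and the mock intercepts the RPC before any channel traffic:

    from unittest import mock

    from google.auth import credentials
    from google.cloud import dataproc_v1
    from google.cloud.dataproc_v1.types import autoscaling_policies

    client = dataproc_v1.AutoscalingPolicyServiceClient(
        credentials=credentials.AnonymousCredentials(),
    )

    # Patching __call__ on the multicallable's type short-circuits the RPC,
    # so nothing is sent over the wire.
    with mock.patch.object(
        type(client.transport.get_autoscaling_policy), "__call__"
    ) as call:
        call.return_value = autoscaling_policies.AutoscalingPolicy(name="name_value")
        response = client.get_autoscaling_policy(
            name="projects/p/regions/r/autoscalingPolicies/x"  # hypothetical name
        )
    assert response.name == "name_value"
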
assert isinstance(response, autoscaling_policies.AutoscalingPolicy) @@ -917,6 +1034,11 @@ async def test_get_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert response.name == "name_value" +@pytest.mark.asyncio +async def test_get_autoscaling_policy_async_from_dict(): + await test_get_autoscaling_policy_async(request_type=dict) + + def test_get_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -929,7 +1051,7 @@ def test_get_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -958,7 +1080,7 @@ async def test_get_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.AutoscalingPolicy() @@ -983,7 +1105,7 @@ def test_get_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -1021,7 +1143,7 @@ async def test_get_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -1069,7 +1191,7 @@ def test_list_autoscaling_policies( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse( @@ -1085,6 +1207,7 @@ def test_list_autoscaling_policies( assert args[0] == autoscaling_policies.ListAutoscalingPoliciesRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAutoscalingPoliciesPager) assert response.next_page_token == "next_page_token_value" @@ -1095,18 +1218,21 @@ def test_list_autoscaling_policies_from_dict(): @pytest.mark.asyncio -async def test_list_autoscaling_policies_async(transport: str = "grpc_asyncio"): +async def test_list_autoscaling_policies_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.ListAutoscalingPoliciesRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. 
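[Editor's illustration] The new `*_async_from_dict` variants added here re-run each async test with `request_type=dict`, leaning on proto-plus accepting a plain mapping wherever a request message is expected. A small sketch of that equivalence, assuming proto-plus semantics:

    from google.cloud.dataproc_v1.types import autoscaling_policies

    # Both spellings construct the same request message.
    from_message = autoscaling_policies.GetAutoscalingPolicyRequest(name="name_value")
    from_dict = autoscaling_policies.GetAutoscalingPolicyRequest({"name": "name_value"})
    assert from_message == from_dict
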
- request = autoscaling_policies.ListAutoscalingPoliciesRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -1121,7 +1247,7 @@ async def test_list_autoscaling_policies_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.ListAutoscalingPoliciesRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAutoscalingPoliciesAsyncPager) @@ -1129,6 +1255,11 @@ async def test_list_autoscaling_policies_async(transport: str = "grpc_asyncio"): assert response.next_page_token == "next_page_token_value" +@pytest.mark.asyncio +async def test_list_autoscaling_policies_async_from_dict(): + await test_list_autoscaling_policies_async(request_type=dict) + + def test_list_autoscaling_policies_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1141,7 +1272,7 @@ def test_list_autoscaling_policies_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1170,7 +1301,7 @@ async def test_list_autoscaling_policies_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1195,7 +1326,7 @@ def test_list_autoscaling_policies_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1234,7 +1365,7 @@ async def test_list_autoscaling_policies_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1276,7 +1407,7 @@ def test_list_autoscaling_policies_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1326,7 +1457,7 @@ def test_list_autoscaling_policies_pages(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1354,8 +1485,8 @@ def test_list_autoscaling_policies_pages(): RuntimeError, ) pages = list(client.list_autoscaling_policies(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio @@ -1366,7 +1497,7 @@ async def test_list_autoscaling_policies_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), + type(client.transport.list_autoscaling_policies), "__call__", new_callable=mock.AsyncMock, ) as call: @@ -1415,7 +1546,7 @@ async def test_list_autoscaling_policies_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), + type(client.transport.list_autoscaling_policies), "__call__", new_callable=mock.AsyncMock, ) as call: @@ -1445,10 +1576,10 @@ async def test_list_autoscaling_policies_async_pages(): RuntimeError, ) pages = [] - async for page in (await client.list_autoscaling_policies(request={})).pages: - pages.append(page) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + async for page_ in (await client.list_autoscaling_policies(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token def test_delete_autoscaling_policy( @@ -1465,7 +1596,7 @@ def test_delete_autoscaling_policy( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1487,18 +1618,21 @@ def test_delete_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_delete_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_delete_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.DeleteAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.DeleteAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. 
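[Editor's illustration] A side note on the `page` to `page_` rename in the pager loops above: the iteration API itself is unchanged. A minimal sketch of the sync pager protocol those loops exercise; the helper name is hypothetical:

    from google.cloud.dataproc_v1.services.autoscaling_policy_service import pagers

    def next_page_tokens(pager: pagers.ListAutoscalingPoliciesPager) -> list:
        # Each page_ wraps a raw ListAutoscalingPoliciesResponse.
        return [page_.raw_page.next_page_token for page_ in pager.pages]
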
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1509,12 +1643,17 @@ async def test_delete_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.DeleteAutoscalingPolicyRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_delete_autoscaling_policy_async_from_dict(): + await test_delete_autoscaling_policy_async(request_type=dict) + + def test_delete_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1527,7 +1666,7 @@ def test_delete_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: call.return_value = None @@ -1556,7 +1695,7 @@ async def test_delete_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1579,7 +1718,7 @@ def test_delete_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1617,7 +1756,7 @@ async def test_delete_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1685,7 +1824,7 @@ def test_transport_instance(): credentials=credentials.AnonymousCredentials(), ) client = AutoscalingPolicyServiceClient(transport=transport) - assert client._transport is transport + assert client.transport is transport def test_transport_get_channel(): @@ -1703,13 +1842,28 @@ def test_transport_get_channel(): assert channel +@pytest.mark.parametrize( + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
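[Editor's illustration] The new `test_transport_adc` above pins down a fallback behavior: constructing a transport with no explicit credentials goes through `google.auth.default()`. The same check as a standalone sketch, with the ADC call mocked so nothing real is fetched:

    from unittest import mock

    from google.auth import credentials
    from google.cloud.dataproc_v1.services.autoscaling_policy_service import transports

    # With no credentials argument, the transport asks ADC for some.
    with mock.patch("google.auth.default") as adc:
        adc.return_value = (credentials.AnonymousCredentials(), None)
        transports.AutoscalingPolicyServiceGrpcTransport()
        adc.assert_called_once()
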
client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), ) assert isinstance( - client._transport, transports.AutoscalingPolicyServiceGrpcTransport, + client.transport, transports.AutoscalingPolicyServiceGrpcTransport, ) @@ -1765,6 +1919,17 @@ def test_autoscaling_policy_service_base_transport_with_credentials_file(): ) +def test_autoscaling_policy_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.dataproc_v1.services.autoscaling_policy_service.transports.AutoscalingPolicyServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.AutoscalingPolicyServiceTransport() + adc.assert_called_once() + + def test_autoscaling_policy_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(auth, "default") as adc: @@ -1797,7 +1962,7 @@ def test_autoscaling_policy_service_host_no_port(): api_endpoint="dataproc.googleapis.com" ), ) - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_autoscaling_policy_service_host_with_port(): @@ -1807,185 +1972,119 @@ def test_autoscaling_policy_service_host_with_port(): api_endpoint="dataproc.googleapis.com:8000" ), ) - assert client._transport._host == "dataproc.googleapis.com:8000" + assert client.transport._host == "dataproc.googleapis.com:8000" def test_autoscaling_policy_service_grpc_transport_channel(): channel = grpc.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.AutoscalingPolicyServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called + assert transport._ssl_channel_credentials == None def test_autoscaling_policy_service_grpc_asyncio_transport_channel(): channel = aio.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. 
- mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.AutoscalingPolicyServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_asyncio_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint +def test_autoscaling_policy_service_transport_channel_mtls_with_client_cert_source( + transport_class, ): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
- mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - mock_cred = mock.Mock() - transport = transports.AutoscalingPolicyServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_asyncio_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
+def test_autoscaling_policy_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - mock_cred = mock.Mock() - transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel def test_autoscaling_policy_path(): @@ -2013,3 +2112,125 @@ def test_parse_autoscaling_policy_path(): # Check that the path construction is reversible. actual = AutoscalingPolicyServiceClient.parse_autoscaling_policy_path(path) assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = AutoscalingPolicyServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = AutoscalingPolicyServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = AutoscalingPolicyServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = AutoscalingPolicyServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = AutoscalingPolicyServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = AutoscalingPolicyServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
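[Editor's illustration] The tests that follow cover the "add common resource paths" half of the change: every generated client gains classmethod builders and parsers for resource names shared across Google APIs (billing account, folder, organization, project, location). A usage sketch with placeholder values:

    from google.cloud import dataproc_v1

    client_cls = dataproc_v1.AutoscalingPolicyServiceClient

    # Build a shared resource name...
    path = client_cls.common_location_path("my-project", "us-central1")
    assert path == "projects/my-project/locations/us-central1"

    # ...and parse it back into its components.
    assert client_cls.parse_common_location_path(path) == {
        "project": "my-project",
        "location": "us-central1",
    }
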
+ actual = AutoscalingPolicyServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = AutoscalingPolicyServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = AutoscalingPolicyServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = AutoscalingPolicyServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = AutoscalingPolicyServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.AutoscalingPolicyServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = AutoscalingPolicyServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.AutoscalingPolicyServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = AutoscalingPolicyServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/dataproc_v1/test_cluster_controller.py b/tests/unit/gapic/dataproc_v1/test_cluster_controller.py index 072ac8ff..c9d47b54 100644 --- a/tests/unit/gapic/dataproc_v1/test_cluster_controller.py +++ b/tests/unit/gapic/dataproc_v1/test_cluster_controller.py @@ -31,7 +31,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async -from google.api_core import operation_async +from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError @@ -42,7 +42,6 @@ from google.cloud.dataproc_v1.services.cluster_controller import pagers from google.cloud.dataproc_v1.services.cluster_controller import transports from google.cloud.dataproc_v1.types import clusters -from google.cloud.dataproc_v1.types import clusters as gcd_clusters from google.cloud.dataproc_v1.types import operations from google.cloud.dataproc_v1.types import shared from google.longrunning import operations_pb2 @@ -107,12 +106,12 @@ def test_cluster_controller_client_from_service_account_file(client_class): ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds + assert 
client.transport._credentials == creds - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_cluster_controller_client_get_transport_class(): @@ -168,14 +167,14 @@ def test_cluster_controller_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -184,14 +183,14 @@ def test_cluster_controller_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -200,90 +199,185 @@ def test_cluster_controller_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + ClusterControllerClient, + transports.ClusterControllerGrpcTransport, + "grpc", + "true", + ), + ( + ClusterControllerAsyncClient, + transports.ClusterControllerGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + ClusterControllerClient, + transports.ClusterControllerGrpcTransport, + "grpc", + "false", + ), + ( + ClusterControllerAsyncClient, + transports.ClusterControllerGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ClusterControllerClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ClusterControllerClient), +) +@mock.patch.object( + ClusterControllerAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ClusterControllerAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_cluster_controller_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and default_client_cert_source is provided. 
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
+            ssl_channel_creds = mock.Mock()
             with mock.patch(
-                "google.auth.transport.mtls.has_default_client_cert_source",
-                return_value=True,
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
             ):
                 patched.return_value = None
-                client = client_class()
+                client = client_class(client_options=options)
+
+                if use_client_cert_env == "false":
+                    expected_ssl_channel_creds = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_ssl_channel_creds = ssl_channel_creds
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT
+
                 patched.assert_called_once_with(
                     credentials=None,
                     credentials_file=None,
-                    host=client.DEFAULT_MTLS_ENDPOINT,
+                    host=expected_host,
                     scopes=None,
-                    api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-                    client_cert_source=None,
+                    ssl_channel_credentials=expected_ssl_channel_creds,
                     quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
                 )

-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", but client_cert_source and default_client_cert_source are None.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
+    # Check the case ADC client cert is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.mtls.has_default_client_cert_source",
-                return_value=False,
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
             ):
-                patched.return_value = None
-                client = client_class()
-                patched.assert_called_once_with(
-                    credentials=None,
-                    credentials_file=None,
-                    host=client.DEFAULT_ENDPOINT,
-                    scopes=None,
-                    api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-                    client_cert_source=None,
-                    quota_project_id=None,
-                )
-
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
-    # unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}):
-        with pytest.raises(MutualTLSChannelError):
-            client = client_class()
-
-    # Check the case quota_project_id is provided
-    options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host=client.DEFAULT_ENDPOINT,
-            scopes=None,
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
-            quota_project_id="octopus",
-        )
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
+                        if use_client_cert_env == "false":
+                            is_mtls_mock.return_value = False
+                            ssl_credentials_mock.return_value = None
+                            expected_host = client.DEFAULT_ENDPOINT
+                            expected_ssl_channel_creds = None
+                        else:
+                            is_mtls_mock.return_value = True
+                            ssl_credentials_mock.return_value = mock.Mock()
+                            expected_host = client.DEFAULT_MTLS_ENDPOINT
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )
+
+                        patched.return_value = None
+                        client = client_class()
+                        patched.assert_called_once_with(
+                            credentials=None,
+                            credentials_file=None,
+                            host=expected_host,
+                            scopes=None,
+                            ssl_channel_credentials=expected_ssl_channel_creds,
+                            quota_project_id=None,
+                            client_info=transports.base.DEFAULT_CLIENT_INFO,
+                        )
+
+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    is_mtls_mock.return_value = False
+                    patched.return_value = None
+                    client = client_class()
+                    patched.assert_called_once_with(
+                        credentials=None,
+                        credentials_file=None,
+                        host=client.DEFAULT_ENDPOINT,
+                        scopes=None,
+                        ssl_channel_credentials=None,
+                        quota_project_id=None,
+                        client_info=transports.base.DEFAULT_CLIENT_INFO,
+                    )
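
The tests above pin down the new switching scheme: the old GOOGLE_API_USE_MTLS variable is replaced by GOOGLE_API_USE_CLIENT_CERTIFICATE (whether a client certificate is offered at all) plus GOOGLE_API_USE_MTLS_ENDPOINT (never/auto/always, which endpoint is targeted), and the transport now receives ready-made ssl_channel_credentials instead of an api_mtls_endpoint/client_cert_source pair. A minimal sketch of how a caller would opt in; this is not part of the patch, and the cert loader below is hypothetical:

import os

from google.api_core.client_options import ClientOptions
from google.cloud import dataproc_v1


def load_client_cert():
    # Hypothetical loader; must return a PEM-encoded (cert_bytes, key_bytes) pair.
    return open("client.crt", "rb").read(), open("client.key", "rb").read()


os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"  # defaults to "false"
os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "auto"  # "never" | "auto" | "always"
client = dataproc_v1.ClusterControllerClient(
    client_options=ClientOptions(client_cert_source=load_client_cert)
)
# With a cert source present and "auto", the client targets the mTLS host
# (DEFAULT_MTLS_ENDPOINT); "never" keeps the regular endpoint, and the
# unsupported values raise MutualTLSChannelError / ValueError, as asserted above.
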
@@ -310,9 +404,9 @@ def test_cluster_controller_client_client_options_scopes(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=["1", "2"],
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -340,9 +434,9 @@ def test_cluster_controller_client_client_options_credentials_file(
             credentials_file="credentials.json",
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -359,9 +453,9 @@ def test_cluster_controller_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        api_mtls_endpoint="squid.clam.whelk",
-        client_cert_source=None,
+        ssl_channel_credentials=None,
         quota_project_id=None,
+        client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -377,7 +471,7 @@ def test_create_cluster(
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.create_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.create_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/spam")
@@ -398,19 +492,19 @@ def test_create_cluster_from_dict():

 @pytest.mark.asyncio
-async def test_create_cluster_async(transport: str = "grpc_asyncio"):
+async def test_create_cluster_async(
+    transport: str = "grpc_asyncio", request_type=clusters.CreateClusterRequest
+):
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = clusters.CreateClusterRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.create_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.create_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             operations_pb2.Operation(name="operations/spam")
         )
@@ -422,17 +516,22 @@ async def test_create_cluster_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == clusters.CreateClusterRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_create_cluster_async_from_dict():
+    await test_create_cluster_async(request_type=dict)
+
+
 def test_create_cluster_flattened():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.create_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.create_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -477,9 +576,7 @@ async def test_create_cluster_flattened_async():
     )

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.create_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.create_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -535,7 +632,7 @@ def test_update_cluster(
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.update_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.update_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/spam")
@@ -556,19 +653,19 @@ def test_update_cluster_from_dict():

 @pytest.mark.asyncio
-async def test_update_cluster_async(transport: str = "grpc_asyncio"):
+async def test_update_cluster_async(
+    transport: str = "grpc_asyncio", request_type=clusters.UpdateClusterRequest
+):
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = clusters.UpdateClusterRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.update_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.update_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             operations_pb2.Operation(name="operations/spam")
         )
@@ -580,17 +677,22 @@ async def test_update_cluster_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == clusters.UpdateClusterRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_update_cluster_async_from_dict():
+    await test_update_cluster_async(request_type=dict)
+
+
 def test_update_cluster_flattened():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.update_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.update_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -643,9 +745,7 @@ async def test_update_cluster_flattened_async():
     )

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.update_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.update_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
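
One consequence of exposing the transport shows up in every hunk here: stubs are now patched through the public client.transport property instead of the private client._transport (sync) and client._client._transport (async) attributes, so both client flavors are patched identically. A sketch of the same pattern outside the generated suite, using anonymous credentials to keep it offline:

from unittest import mock

from google.auth import credentials
from google.cloud import dataproc_v1
from google.longrunning import operations_pb2

client = dataproc_v1.ClusterControllerClient(
    credentials=credentials.AnonymousCredentials()
)
# client.transport is the public handle to the underlying gRPC transport.
with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call:
    call.return_value = operations_pb2.Operation(name="operations/spam")
    client.delete_cluster(
        project_id="my-project", region="us-central1", cluster_name="my-cluster"
    )
    call.assert_called_once()
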
@@ -709,7 +809,7 @@ def test_delete_cluster(
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.delete_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/spam")
@@ -730,19 +830,19 @@ def test_delete_cluster_from_dict():

 @pytest.mark.asyncio
-async def test_delete_cluster_async(transport: str = "grpc_asyncio"):
+async def test_delete_cluster_async(
+    transport: str = "grpc_asyncio", request_type=clusters.DeleteClusterRequest
+):
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = clusters.DeleteClusterRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.delete_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             operations_pb2.Operation(name="operations/spam")
         )
@@ -754,17 +854,22 @@ async def test_delete_cluster_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == clusters.DeleteClusterRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_delete_cluster_async_from_dict():
+    await test_delete_cluster_async(request_type=dict)
+
+
 def test_delete_cluster_flattened():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.delete_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -809,9 +914,7 @@ async def test_delete_cluster_flattened_async():
     )

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.delete_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -865,7 +968,7 @@ def test_get_cluster(transport: str = "grpc", request_type=clusters.GetClusterRe
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.get_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.get_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = clusters.Cluster(
             project_id="project_id_value",
@@ -882,6 +985,7 @@ def test_get_cluster(transport: str = "grpc", request_type=clusters.GetClusterRe
     assert args[0] == clusters.GetClusterRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, clusters.Cluster)

     assert response.project_id == "project_id_value"
@@ -896,19 +1000,19 @@ def test_get_cluster_from_dict():

 @pytest.mark.asyncio
-async def test_get_cluster_async(transport: str = "grpc_asyncio"):
+async def test_get_cluster_async(
+    transport: str = "grpc_asyncio", request_type=clusters.GetClusterRequest
+):
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = clusters.GetClusterRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.get_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.get_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             clusters.Cluster(
@@ -924,7 +1028,7 @@ async def test_get_cluster_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == clusters.GetClusterRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, clusters.Cluster)
@@ -936,11 +1040,16 @@ async def test_get_cluster_async(transport: str = "grpc_asyncio"):
     assert response.cluster_uuid == "cluster_uuid_value"


+@pytest.mark.asyncio
+async def test_get_cluster_async_from_dict():
+    await test_get_cluster_async(request_type=dict)
+
+
 def test_get_cluster_flattened():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.get_cluster), "__call__") as call:
+    with mock.patch.object(type(client.transport.get_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = clusters.Cluster()
@@ -985,9 +1094,7 @@ async def test_get_cluster_flattened_async():
     )

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.get_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.get_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = clusters.Cluster()
@@ -1041,7 +1148,7 @@ def test_list_clusters(
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_clusters), "__call__") as call:
+    with mock.patch.object(type(client.transport.list_clusters), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = clusters.ListClustersResponse(
             next_page_token="next_page_token_value",
@@ -1056,6 +1163,7 @@ def test_list_clusters(
     assert args[0] == clusters.ListClustersRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListClustersPager)

     assert response.next_page_token == "next_page_token_value"
@@ -1066,19 +1174,19 @@ def test_list_clusters_from_dict():

 @pytest.mark.asyncio
-async def test_list_clusters_async(transport: str = "grpc_asyncio"):
+async def test_list_clusters_async(
+    transport: str = "grpc_asyncio", request_type=clusters.ListClustersRequest
+):
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = clusters.ListClustersRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_clusters), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.list_clusters), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             clusters.ListClustersResponse(next_page_token="next_page_token_value",)
         )
@@ -1090,7 +1198,7 @@ async def test_list_clusters_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == clusters.ListClustersRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListClustersAsyncPager)
@@ -1098,11 +1206,16 @@ async def test_list_clusters_async(transport: str = "grpc_asyncio"):
     assert response.next_page_token == "next_page_token_value"


+@pytest.mark.asyncio
+async def test_list_clusters_async_from_dict():
+    await test_list_clusters_async(request_type=dict)
+
+
 def test_list_clusters_flattened():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_clusters), "__call__") as call:
+    with mock.patch.object(type(client.transport.list_clusters), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = clusters.ListClustersResponse()
@@ -1145,9 +1258,7 @@ async def test_list_clusters_flattened_async():
     )

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_clusters), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.list_clusters), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = clusters.ListClustersResponse()
@@ -1193,7 +1304,7 @@ def test_list_clusters_pager():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_clusters), "__call__") as call:
+    with mock.patch.object(type(client.transport.list_clusters), "__call__") as call:
         # Set the response to a series of pages.
         call.side_effect = (
             clusters.ListClustersResponse(
@@ -1224,7 +1335,7 @@ def test_list_clusters_pages():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials,)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.list_clusters), "__call__") as call:
+    with mock.patch.object(type(client.transport.list_clusters), "__call__") as call:
         # Set the response to a series of pages.
         call.side_effect = (
             clusters.ListClustersResponse(
@@ -1241,8 +1352,8 @@ def test_list_clusters_pages():
             RuntimeError,
         )
         pages = list(client.list_clusters(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token


 @pytest.mark.asyncio
@@ -1251,9 +1362,7 @@ async def test_list_clusters_async_pager():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_clusters),
-        "__call__",
-        new_callable=mock.AsyncMock,
+        type(client.transport.list_clusters), "__call__", new_callable=mock.AsyncMock
     ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
@@ -1286,9 +1395,7 @@ async def test_list_clusters_async_pages():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_clusters),
-        "__call__",
-        new_callable=mock.AsyncMock,
+        type(client.transport.list_clusters), "__call__", new_callable=mock.AsyncMock
     ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
@@ -1306,10 +1413,10 @@ async def test_list_clusters_async_pages():
             RuntimeError,
         )
         pages = []
-        async for page in (await client.list_clusters(request={})).pages:
-            pages.append(page)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
+        async for page_ in (await client.list_clusters(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token


 def test_diagnose_cluster(
@@ -1324,9 +1431,7 @@ def test_diagnose_cluster(
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.diagnose_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/spam")
@@ -1347,19 +1452,19 @@ def test_diagnose_cluster_from_dict():

 @pytest.mark.asyncio
-async def test_diagnose_cluster_async(transport: str = "grpc_asyncio"):
+async def test_diagnose_cluster_async(
+    transport: str = "grpc_asyncio", request_type=clusters.DiagnoseClusterRequest
+):
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = clusters.DiagnoseClusterRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.diagnose_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             operations_pb2.Operation(name="operations/spam")
         )
@@ -1371,19 +1476,22 @@ async def test_diagnose_cluster_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == clusters.DiagnoseClusterRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_diagnose_cluster_async_from_dict():
+    await test_diagnose_cluster_async(request_type=dict)
+
+
 def test_diagnose_cluster_flattened():
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.diagnose_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1428,9 +1536,7 @@ async def test_diagnose_cluster_flattened_async():
     )

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.diagnose_cluster), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1510,7 +1616,7 @@ def test_transport_instance():
         credentials=credentials.AnonymousCredentials(),
     )
     client = ClusterControllerClient(transport=transport)
-    assert client._transport is transport
+    assert client.transport is transport


 def test_transport_get_channel():
@@ -1528,10 +1634,25 @@ def test_transport_get_channel():
     assert channel


+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.ClusterControllerGrpcTransport,
+        transports.ClusterControllerGrpcAsyncIOTransport,
+    ],
+)
+def test_transport_adc(transport_class):
+    # Test default credentials are used if not provided.
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transport_class()
+        adc.assert_called_once()
+
+
 def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
     client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client._transport, transports.ClusterControllerGrpcTransport,)
+    assert isinstance(client.transport, transports.ClusterControllerGrpcTransport,)


 def test_cluster_controller_base_transport_error():
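
The new test_transport_adc case covers the fallback path: a transport built with no explicit credentials routes through google.auth.default(). Together with the now-public property, the transport internals are reachable without private attributes; a rough sketch, assuming application default credentials are configured in the environment:

from google.cloud import dataproc_v1

client = dataproc_v1.ClusterControllerClient()  # no credentials: google.auth.default()
transport = client.transport  # public property; replaces the old client._transport
print(type(transport).__name__)  # "ClusterControllerGrpcTransport" by default
channel = transport.grpc_channel  # the underlying grpc.Channel, if ever needed
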
@@ -1592,6 +1713,17 @@ def test_cluster_controller_base_transport_with_credentials_file():
     )


+def test_cluster_controller_base_transport_with_adc():
+    # Test the default credentials are used if credentials and credentials_file are None.
+    with mock.patch.object(auth, "default") as adc, mock.patch(
+        "google.cloud.dataproc_v1.services.cluster_controller.transports.ClusterControllerTransport._prep_wrapped_messages"
+    ) as Transport:
+        Transport.return_value = None
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transport = transports.ClusterControllerTransport()
+        adc.assert_called_once()
+
+
 def test_cluster_controller_auth_adc():
     # If no credentials are provided, we should use ADC credentials.
     with mock.patch.object(auth, "default") as adc:
@@ -1624,7 +1756,7 @@ def test_cluster_controller_host_no_port():
             api_endpoint="dataproc.googleapis.com"
         ),
     )
-    assert client._transport._host == "dataproc.googleapis.com:443"
+    assert client.transport._host == "dataproc.googleapis.com:443"


 def test_cluster_controller_host_with_port():
@@ -1634,192 +1766,126 @@ def test_cluster_controller_host_with_port():
             api_endpoint="dataproc.googleapis.com:8000"
         ),
    )
-    assert client._transport._host == "dataproc.googleapis.com:8000"
+    assert client.transport._host == "dataproc.googleapis.com:8000"


 def test_cluster_controller_grpc_transport_channel():
     channel = grpc.insecure_channel("http://localhost/")

-    # Check that if channel is provided, mtls endpoint and client_cert_source
-    # won't be used.
-    callback = mock.MagicMock()
+    # Check that channel is used if provided.
     transport = transports.ClusterControllerGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=callback,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
-    assert not callback.called
+    assert transport._ssl_channel_credentials == None


 def test_cluster_controller_grpc_asyncio_transport_channel():
     channel = aio.insecure_channel("http://localhost/")

-    # Check that if channel is provided, mtls endpoint and client_cert_source
-    # won't be used.
-    callback = mock.MagicMock()
+    # Check that channel is used if provided.
     transport = transports.ClusterControllerGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=callback,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
-    assert not callback.called
-
-
-@mock.patch("grpc.ssl_channel_credentials", autospec=True)
-@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True)
-def test_cluster_controller_grpc_transport_channel_mtls_with_client_cert_source(
-    grpc_create_channel, grpc_ssl_channel_cred
-):
-    # Check that if channel is None, but api_mtls_endpoint and client_cert_source
-    # are provided, then a mTLS channel will be created.
-    mock_cred = mock.Mock()
-
-    mock_ssl_cred = mock.Mock()
-    grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    transport = transports.ClusterControllerGrpcTransport(
-        host="squid.clam.whelk",
-        credentials=mock_cred,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=client_cert_source_callback,
-    )
-    grpc_ssl_channel_cred.assert_called_once_with(
-        certificate_chain=b"cert bytes", private_key=b"key bytes"
-    )
-    grpc_create_channel.assert_called_once_with(
-        "mtls.squid.clam.whelk:443",
-        credentials=mock_cred,
-        credentials_file=None,
-        scopes=("https://www.googleapis.com/auth/cloud-platform",),
-        ssl_credentials=mock_ssl_cred,
-        quota_project_id=None,
-    )
-    assert transport.grpc_channel == mock_grpc_channel
-
-
-@mock.patch("grpc.ssl_channel_credentials", autospec=True)
-@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True)
-def test_cluster_controller_grpc_asyncio_transport_channel_mtls_with_client_cert_source(
-    grpc_create_channel, grpc_ssl_channel_cred
-):
-    # Check that if channel is None, but api_mtls_endpoint and client_cert_source
-    # are provided, then a mTLS channel will be created.
-    mock_cred = mock.Mock()
-
-    mock_ssl_cred = mock.Mock()
-    grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    transport = transports.ClusterControllerGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        credentials=mock_cred,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=client_cert_source_callback,
-    )
-    grpc_ssl_channel_cred.assert_called_once_with(
-        certificate_chain=b"cert bytes", private_key=b"key bytes"
-    )
-    grpc_create_channel.assert_called_once_with(
-        "mtls.squid.clam.whelk:443",
-        credentials=mock_cred,
-        credentials_file=None,
-        scopes=("https://www.googleapis.com/auth/cloud-platform",),
-        ssl_credentials=mock_ssl_cred,
-        quota_project_id=None,
-    )
-    assert transport.grpc_channel == mock_grpc_channel
+    assert transport._ssl_channel_credentials == None


 @pytest.mark.parametrize(
-    "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"]
+    "transport_class",
+    [
+        transports.ClusterControllerGrpcTransport,
+        transports.ClusterControllerGrpcAsyncIOTransport,
+    ],
 )
-@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True)
-def test_cluster_controller_grpc_transport_channel_mtls_with_adc(
-    grpc_create_channel, api_mtls_endpoint
+def test_cluster_controller_transport_channel_mtls_with_client_cert_source(
+    transport_class,
 ):
-    # Check that if channel and client_cert_source are None, but api_mtls_endpoint
-    # is provided, then a mTLS channel will be created with SSL ADC.
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    # Mock google.auth.transport.grpc.SslCredentials class.
-    mock_ssl_cred = mock.Mock()
-    with mock.patch.multiple(
-        "google.auth.transport.grpc.SslCredentials",
-        __init__=mock.Mock(return_value=None),
-        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
-    ):
-        mock_cred = mock.Mock()
-        transport = transports.ClusterControllerGrpcTransport(
-            host="squid.clam.whelk",
-            credentials=mock_cred,
-            api_mtls_endpoint=api_mtls_endpoint,
-            client_cert_source=None,
-        )
-        grpc_create_channel.assert_called_once_with(
-            "mtls.squid.clam.whelk:443",
-            credentials=mock_cred,
-            credentials_file=None,
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            ssl_credentials=mock_ssl_cred,
-            quota_project_id=None,
-        )
-        assert transport.grpc_channel == mock_grpc_channel
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred


 @pytest.mark.parametrize(
-    "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"]
+    "transport_class",
+    [
+        transports.ClusterControllerGrpcTransport,
+        transports.ClusterControllerGrpcAsyncIOTransport,
+    ],
 )
-@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True)
-def test_cluster_controller_grpc_asyncio_transport_channel_mtls_with_adc(
-    grpc_create_channel, api_mtls_endpoint
-):
-    # Check that if channel and client_cert_source are None, but api_mtls_endpoint
-    # is provided, then a mTLS channel will be created with SSL ADC.
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    # Mock google.auth.transport.grpc.SslCredentials class.
+def test_cluster_controller_transport_channel_mtls_with_adc(transport_class):
     mock_ssl_cred = mock.Mock()
     with mock.patch.multiple(
         "google.auth.transport.grpc.SslCredentials",
         __init__=mock.Mock(return_value=None),
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
-        mock_cred = mock.Mock()
-        transport = transports.ClusterControllerGrpcAsyncIOTransport(
-            host="squid.clam.whelk",
-            credentials=mock_cred,
-            api_mtls_endpoint=api_mtls_endpoint,
-            client_cert_source=None,
-        )
-        grpc_create_channel.assert_called_once_with(
-            "mtls.squid.clam.whelk:443",
-            credentials=mock_cred,
-            credentials_file=None,
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            ssl_credentials=mock_ssl_cred,
-            quota_project_id=None,
-        )
-        assert transport.grpc_channel == mock_grpc_channel
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel


 def test_cluster_controller_grpc_lro_client():
     client = ClusterControllerClient(
         credentials=credentials.AnonymousCredentials(), transport="grpc",
     )
-    transport = client._transport
+    transport = client.transport

     # Ensure that we have a api-core operations client.
     assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
@@ -1832,10 +1898,132 @@ def test_cluster_controller_grpc_lro_async_client():
     client = ClusterControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
     )
-    transport = client._client._transport
+    transport = client.transport

     # Ensure that we have a api-core operations client.
     assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client
+
+
+def test_common_billing_account_path():
+    billing_account = "squid"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = ClusterControllerClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "clam",
+    }
+    path = ClusterControllerClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ClusterControllerClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "whelk"
+
+    expected = "folders/{folder}".format(folder=folder,)
+    actual = ClusterControllerClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "octopus",
+    }
+    path = ClusterControllerClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ClusterControllerClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "oyster"
+
+    expected = "organizations/{organization}".format(organization=organization,)
+    actual = ClusterControllerClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "nudibranch",
+    }
+    path = ClusterControllerClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ClusterControllerClient.parse_common_organization_path(path)
+    assert expected == actual
+
+
+def test_common_project_path():
+    project = "cuttlefish"
+
+    expected = "projects/{project}".format(project=project,)
+    actual = ClusterControllerClient.common_project_path(project)
+    assert expected == actual
+
+
+def test_parse_common_project_path():
+    expected = {
+        "project": "mussel",
+    }
+    path = ClusterControllerClient.common_project_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ClusterControllerClient.parse_common_project_path(path)
+    assert expected == actual
+
+
+def test_common_location_path():
+    project = "winkle"
+    location = "nautilus"
+
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
+    actual = ClusterControllerClient.common_location_path(project, location)
+    assert expected == actual
+
+
+def test_parse_common_location_path():
+    expected = {
+        "project": "scallop",
+        "location": "abalone",
+    }
+    path = ClusterControllerClient.common_location_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ClusterControllerClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_withDEFAULT_CLIENT_INFO():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(
+        transports.ClusterControllerTransport, "_prep_wrapped_messages"
+    ) as prep:
+        client = ClusterControllerClient(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(
+        transports.ClusterControllerTransport, "_prep_wrapped_messages"
+    ) as prep:
+        transport_class = ClusterControllerClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
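
The test_common_*_path / test_parse_common_*_path pairs above exercise the headline "common resource paths" feature: every generated client now carries classmethod helpers that build and parse the resource names shared across Google APIs (billing account, folder, organization, project, location). Roughly, per the formats asserted above:

from google.cloud import dataproc_v1

Client = dataproc_v1.ClusterControllerClient

path = Client.common_location_path("my-project", "us-central1")
# -> "projects/my-project/locations/us-central1"
assert Client.parse_common_location_path(path) == {
    "project": "my-project",
    "location": "us-central1",
}
# common_folder_path, common_organization_path, common_project_path, and
# common_billing_account_path follow the same build/parse pattern.
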
diff --git a/tests/unit/gapic/dataproc_v1/test_job_controller.py b/tests/unit/gapic/dataproc_v1/test_job_controller.py
index 83403e9f..b19332a4 100644
--- a/tests/unit/gapic/dataproc_v1/test_job_controller.py
+++ b/tests/unit/gapic/dataproc_v1/test_job_controller.py
@@ -31,7 +31,7 @@
 from google.api_core import gapic_v1
 from google.api_core import grpc_helpers
 from google.api_core import grpc_helpers_async
-from google.api_core import operation_async
+from google.api_core import operation_async  # type: ignore
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
@@ -40,7 +40,6 @@
 from google.cloud.dataproc_v1.services.job_controller import pagers
 from google.cloud.dataproc_v1.services.job_controller import transports
 from google.cloud.dataproc_v1.types import jobs
-from google.cloud.dataproc_v1.types import jobs as gcd_jobs
 from google.longrunning import operations_pb2
 from google.oauth2 import service_account
 from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
@@ -101,12 +100,12 @@ def test_job_controller_client_from_service_account_file(client_class):
     ) as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
-        assert client._transport._credentials == creds
+        assert client.transport._credentials == creds

         client = client_class.from_service_account_json("dummy/file/path.json")
-        assert client._transport._credentials == creds
+        assert client.transport._credentials == creds

-    assert client._transport._host == "dataproc.googleapis.com:443"
+    assert client.transport._host == "dataproc.googleapis.com:443"


 def test_job_controller_client_get_transport_class():
@@ -162,14 +161,14 @@ def test_job_controller_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            api_mtls_endpoint="squid.clam.whelk",
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "never".
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}):
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
         with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
@@ -178,14 +177,14 @@ def test_job_controller_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "always".
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}):
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
         with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
@@ -194,90 +193,175 @@ def test_job_controller_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_MTLS_ENDPOINT,
             scopes=None,
-            api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", and client_cert_source is provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+    # unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+        with pytest.raises(MutualTLSChannelError):
+            client = client_class()
+
+    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
+        with pytest.raises(ValueError):
+            client = client_class()
+
+    # Check the case quota_project_id is provided
+    options = client_options.ClientOptions(quota_project_id="octopus")
+    with mock.patch.object(transport_class, "__init__") as patched:
+        patched.return_value = None
+        client = client_class(client_options=options)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host=client.DEFAULT_ENDPOINT,
+            scopes=None,
+            ssl_channel_credentials=None,
+            quota_project_id="octopus",
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+        )
+
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name,use_client_cert_env",
+    [
+        (JobControllerClient, transports.JobControllerGrpcTransport, "grpc", "true"),
+        (
+            JobControllerAsyncClient,
+            transports.JobControllerGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "true",
+        ),
+        (JobControllerClient, transports.JobControllerGrpcTransport, "grpc", "false"),
+        (
+            JobControllerAsyncClient,
+            transports.JobControllerGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "false",
+        ),
+    ],
+)
+@mock.patch.object(
+    JobControllerClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(JobControllerClient),
+)
+@mock.patch.object(
+    JobControllerAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(JobControllerAsyncClient),
+)
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_job_controller_client_mtls_env_auto(
+    client_class, transport_class, transport_name, use_client_cert_env
+):
+    # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+    # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+    # Check the case client_cert_source is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
         options = client_options.ClientOptions(
             client_cert_source=client_cert_source_callback
         )
         with mock.patch.object(transport_class, "__init__") as patched:
-            patched.return_value = None
-            client = client_class(client_options=options)
-            patched.assert_called_once_with(
-                credentials=None,
-                credentials_file=None,
-                host=client.DEFAULT_MTLS_ENDPOINT,
-                scopes=None,
-                api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-                client_cert_source=client_cert_source_callback,
-                quota_project_id=None,
-            )
-
-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", and default_client_cert_source is provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
+            ssl_channel_creds = mock.Mock()
             with mock.patch(
-                "google.auth.transport.mtls.has_default_client_cert_source",
-                return_value=True,
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
             ):
                 patched.return_value = None
-                client = client_class()
+                client = client_class(client_options=options)
+
+                if use_client_cert_env == "false":
+                    expected_ssl_channel_creds = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_ssl_channel_creds = ssl_channel_creds
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT
+
                 patched.assert_called_once_with(
                     credentials=None,
                     credentials_file=None,
-                    host=client.DEFAULT_MTLS_ENDPOINT,
+                    host=expected_host,
                     scopes=None,
-                    api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-                    client_cert_source=None,
+                    ssl_channel_credentials=expected_ssl_channel_creds,
                     quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
                 )

-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", but client_cert_source and default_client_cert_source are None.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
+    # Check the case ADC client cert is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.mtls.has_default_client_cert_source",
-                return_value=False,
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
             ):
-                patched.return_value = None
-                client = client_class()
-                patched.assert_called_once_with(
-                    credentials=None,
-                    credentials_file=None,
-                    host=client.DEFAULT_ENDPOINT,
-                    scopes=None,
-                    api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-                    client_cert_source=None,
-                    quota_project_id=None,
-                )
-
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
-    # unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}):
-        with pytest.raises(MutualTLSChannelError):
-            client = client_class()
-
-    # Check the case quota_project_id is provided
-    options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host=client.DEFAULT_ENDPOINT,
-            scopes=None,
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
-            quota_project_id="octopus",
-        )
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
+                        if use_client_cert_env == "false":
+                            is_mtls_mock.return_value = False
+                            ssl_credentials_mock.return_value = None
+                            expected_host = client.DEFAULT_ENDPOINT
+                            expected_ssl_channel_creds = None
+                        else:
+                            is_mtls_mock.return_value = True
+                            ssl_credentials_mock.return_value = mock.Mock()
+                            expected_host = client.DEFAULT_MTLS_ENDPOINT
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )
+
+                        patched.return_value = None
+                        client = client_class()
+                        patched.assert_called_once_with(
+                            credentials=None,
+                            credentials_file=None,
+                            host=expected_host,
+                            scopes=None,
+                            ssl_channel_credentials=expected_ssl_channel_creds,
+                            quota_project_id=None,
+                            client_info=transports.base.DEFAULT_CLIENT_INFO,
+                        )
+
+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    is_mtls_mock.return_value = False
+                    patched.return_value = None
+                    client = client_class()
+                    patched.assert_called_once_with(
+                        credentials=None,
+                        credentials_file=None,
+                        host=client.DEFAULT_ENDPOINT,
+                        scopes=None,
+                        ssl_channel_credentials=None,
+                        quota_project_id=None,
+                        client_info=transports.base.DEFAULT_CLIENT_INFO,
+                    )


 @pytest.mark.parametrize(
@@ -304,9 +388,9 @@ def test_job_controller_client_client_options_scopes(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=["1", "2"],
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -334,9 +418,9 @@ def test_job_controller_client_client_options_credentials_file(
             credentials_file="credentials.json",
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -353,9 +437,9 @@ def test_job_controller_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        api_mtls_endpoint="squid.clam.whelk",
-        client_cert_source=None,
+        ssl_channel_credentials=None,
         quota_project_id=None,
+        client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -369,7 +453,7 @@ def test_submit_job(transport: str = "grpc", request_type=jobs.SubmitJobRequest)
     request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.submit_job), "__call__") as call:
+    with mock.patch.object(type(client.transport.submit_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = jobs.Job(
             driver_output_resource_uri="driver_output_resource_uri_value",
@@ -388,6 +472,7 @@ def test_submit_job(transport: str = "grpc", request_type=jobs.SubmitJobRequest)
     assert args[0] == jobs.SubmitJobRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, jobs.Job)

     assert response.driver_output_resource_uri == "driver_output_resource_uri_value"
@@ -404,19 +489,19 @@ def test_submit_job_from_dict():

 @pytest.mark.asyncio
-async def test_submit_job_async(transport: str = "grpc_asyncio"):
+async def test_submit_job_async(
+    transport: str = "grpc_asyncio", request_type=jobs.SubmitJobRequest
+):
     client = JobControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = jobs.SubmitJobRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.submit_job), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.submit_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             jobs.Job(
@@ -433,7 +518,7 @@ async def test_submit_job_async(transport: str = "grpc_asyncio"):
     assert len(call.mock_calls)
     _, args, _ = call.mock_calls[0]
-    assert args[0] == request
+    assert args[0] == jobs.SubmitJobRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, jobs.Job)
@@ -447,11 +532,16 @@ async def test_submit_job_async(transport: str = "grpc_asyncio"):
     assert response.done is True


+@pytest.mark.asyncio
+async def test_submit_job_async_from_dict():
+    await test_submit_job_async(request_type=dict)
+
+
 def test_submit_job_flattened():
     client = JobControllerClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.submit_job), "__call__") as call:
+    with mock.patch.object(type(client.transport.submit_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = jobs.Job()
@@ -496,9 +586,7 @@ async def test_submit_job_flattened_async():
     client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),)

     # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.submit_job), "__call__"
-    ) as call:
+    with mock.patch.object(type(client.transport.submit_job), "__call__") as call:
         # Designate an appropriate return value for the call.
         call.return_value = jobs.Job()
@@ -553,7 +641,7 @@ def test_submit_job_as_operation(

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.submit_job_as_operation), "__call__"
+        type(client.transport.submit_job_as_operation), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
call.return_value = operations_pb2.Operation(name="operations/spam") @@ -575,18 +663,20 @@ def test_submit_job_as_operation_from_dict(): @pytest.mark.asyncio -async def test_submit_job_as_operation_async(transport: str = "grpc_asyncio"): +async def test_submit_job_as_operation_async( + transport: str = "grpc_asyncio", request_type=jobs.SubmitJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.SubmitJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -599,18 +689,23 @@ async def test_submit_job_as_operation_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.SubmitJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_submit_job_as_operation_async_from_dict(): + await test_submit_job_as_operation_async(request_type=dict) + + def test_submit_job_as_operation_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -657,7 +752,7 @@ async def test_submit_job_as_operation_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -712,7 +807,7 @@ def test_get_job(transport: str = "grpc", request_type=jobs.GetJobRequest): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( driver_output_resource_uri="driver_output_resource_uri_value", @@ -731,6 +826,7 @@ def test_get_job(transport: str = "grpc", request_type=jobs.GetJobRequest): assert args[0] == jobs.GetJobRequest() # Establish that the response is the type that we expect. 
+ assert isinstance(response, jobs.Job) assert response.driver_output_resource_uri == "driver_output_resource_uri_value" @@ -747,17 +843,19 @@ def test_get_job_from_dict(): @pytest.mark.asyncio -async def test_get_job_async(transport: str = "grpc_asyncio"): +async def test_get_job_async( + transport: str = "grpc_asyncio", request_type=jobs.GetJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.GetJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -774,7 +872,7 @@ async def test_get_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.GetJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -788,11 +886,16 @@ async def test_get_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_get_job_async_from_dict(): + await test_get_job_async(request_type=dict) + + def test_get_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -833,7 +936,7 @@ async def test_get_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -881,7 +984,7 @@ def test_list_jobs(transport: str = "grpc", request_type=jobs.ListJobsRequest): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.ListJobsResponse( next_page_token="next_page_token_value", @@ -896,6 +999,7 @@ def test_list_jobs(transport: str = "grpc", request_type=jobs.ListJobsRequest): assert args[0] == jobs.ListJobsRequest() # Establish that the response is the type that we expect. 
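# The flattened tests above exercise the two calling conventions the generated
# clients accept: a fully-formed request object, or individual fields
# (project_id, region, job) that the client packs into one; mixing both is
# rejected. A minimal sketch of that packing rule, assuming a hypothetical
# make_request helper rather than the generated method:
def make_request(request=None, *, project_id=None, region=None, job=None):
    flattened = (project_id, region, job)
    if request is not None and any(v is not None for v in flattened):
        # The generated clients raise for this combination as well.
        raise ValueError("If 'request' is set, flattened fields must not be.")
    if request is None:
        request = {"project_id": project_id, "region": region, "job": job}
    return request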
+ assert isinstance(response, pagers.ListJobsPager) assert response.next_page_token == "next_page_token_value" @@ -906,19 +1010,19 @@ def test_list_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_jobs_async(transport: str = "grpc_asyncio"): +async def test_list_jobs_async( + transport: str = "grpc_asyncio", request_type=jobs.ListJobsRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.ListJobsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_jobs), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.ListJobsResponse(next_page_token="next_page_token_value",) @@ -930,7 +1034,7 @@ async def test_list_jobs_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.ListJobsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListJobsAsyncPager) @@ -938,11 +1042,16 @@ async def test_list_jobs_async(transport: str = "grpc_asyncio"): assert response.next_page_token == "next_page_token_value" +@pytest.mark.asyncio +async def test_list_jobs_async_from_dict(): + await test_list_jobs_async(request_type=dict) + + def test_list_jobs_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.ListJobsResponse() @@ -983,9 +1092,7 @@ async def test_list_jobs_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_jobs), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.ListJobsResponse() @@ -1029,7 +1136,7 @@ def test_list_jobs_pager(): client = JobControllerClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( jobs.ListJobsResponse( @@ -1055,7 +1162,7 @@ def test_list_jobs_pages(): client = JobControllerClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( jobs.ListJobsResponse( @@ -1067,8 +1174,8 @@ def test_list_jobs_pages(): RuntimeError, ) pages = list(client.list_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio @@ -1077,9 +1184,7 @@ async def test_list_jobs_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_jobs), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_jobs), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1107,9 +1212,7 @@ async def test_list_jobs_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_jobs), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_jobs), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1122,10 +1225,10 @@ async def test_list_jobs_async_pages(): RuntimeError, ) pages = [] - async for page in (await client.list_jobs(request={})).pages: - pages.append(page) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + async for page_ in (await client.list_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token def test_update_job(transport: str = "grpc", request_type=jobs.UpdateJobRequest): @@ -1138,7 +1241,7 @@ def test_update_job(transport: str = "grpc", request_type=jobs.UpdateJobRequest) request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_job), "__call__") as call: + with mock.patch.object(type(client.transport.update_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( driver_output_resource_uri="driver_output_resource_uri_value", @@ -1157,6 +1260,7 @@ def test_update_job(transport: str = "grpc", request_type=jobs.UpdateJobRequest) assert args[0] == jobs.UpdateJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, jobs.Job) assert response.driver_output_resource_uri == "driver_output_resource_uri_value" @@ -1173,19 +1277,19 @@ def test_update_job_from_dict(): @pytest.mark.asyncio -async def test_update_job_async(transport: str = "grpc_asyncio"): +async def test_update_job_async( + transport: str = "grpc_asyncio", request_type=jobs.UpdateJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.UpdateJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.update_job), "__call__") as call: # Designate an appropriate return value for the call. 
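# The pager tests above walk a three-page response sequence and compare
# raw_page.next_page_token on each page. Underneath is plain page-token
# pagination: every response carries next_page_token, and an empty token marks
# the final page. A minimal sketch of that loop, assuming a hypothetical
# fetch_page(token) callable in place of the mocked stub:
def iterate_pages(fetch_page):
    token = ""
    while True:
        page = fetch_page(token)  # e.g. returns a jobs.ListJobsResponse
        yield page
        token = page.next_page_token
        if not token:  # an empty token ends the sequence
            break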
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -1202,7 +1306,7 @@ async def test_update_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.UpdateJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -1216,6 +1320,11 @@ async def test_update_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_update_job_async_from_dict(): + await test_update_job_async(request_type=dict) + + def test_cancel_job(transport: str = "grpc", request_type=jobs.CancelJobRequest): client = JobControllerClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1226,7 +1335,7 @@ def test_cancel_job(transport: str = "grpc", request_type=jobs.CancelJobRequest) request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.cancel_job), "__call__") as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( driver_output_resource_uri="driver_output_resource_uri_value", @@ -1245,6 +1354,7 @@ def test_cancel_job(transport: str = "grpc", request_type=jobs.CancelJobRequest) assert args[0] == jobs.CancelJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, jobs.Job) assert response.driver_output_resource_uri == "driver_output_resource_uri_value" @@ -1261,19 +1371,19 @@ def test_cancel_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_job_async(transport: str = "grpc_asyncio"): +async def test_cancel_job_async( + transport: str = "grpc_asyncio", request_type=jobs.CancelJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.CancelJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -1290,7 +1400,7 @@ async def test_cancel_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.CancelJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -1304,11 +1414,16 @@ async def test_cancel_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_cancel_job_async_from_dict(): + await test_cancel_job_async(request_type=dict) + + def test_cancel_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.cancel_job), "__call__") as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = jobs.Job() @@ -1349,9 +1464,7 @@ async def test_cancel_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -1399,7 +1512,7 @@ def test_delete_job(transport: str = "grpc", request_type=jobs.DeleteJobRequest) request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_job), "__call__") as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1420,19 +1533,19 @@ def test_delete_job_from_dict(): @pytest.mark.asyncio -async def test_delete_job_async(transport: str = "grpc_asyncio"): +async def test_delete_job_async( + transport: str = "grpc_asyncio", request_type=jobs.DeleteJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.DeleteJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1442,17 +1555,22 @@ async def test_delete_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.DeleteJobRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_delete_job_async_from_dict(): + await test_delete_job_async(request_type=dict) + + def test_delete_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_job), "__call__") as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1493,9 +1611,7 @@ async def test_delete_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = None @@ -1569,7 +1685,7 @@ def test_transport_instance(): credentials=credentials.AnonymousCredentials(), ) client = JobControllerClient(transport=transport) - assert client._transport is transport + assert client.transport is transport def test_transport_get_channel(): @@ -1587,10 +1703,25 @@ def test_transport_get_channel(): assert channel +@pytest.mark.parametrize( + "transport_class", + [ + transports.JobControllerGrpcTransport, + transports.JobControllerGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.JobControllerGrpcTransport,) + assert isinstance(client.transport, transports.JobControllerGrpcTransport,) def test_job_controller_base_transport_error(): @@ -1652,6 +1783,17 @@ def test_job_controller_base_transport_with_credentials_file(): ) +def test_job_controller_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.dataproc_v1.services.job_controller.transports.JobControllerTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobControllerTransport() + adc.assert_called_once() + + def test_job_controller_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(auth, "default") as adc: @@ -1684,7 +1826,7 @@ def test_job_controller_host_no_port(): api_endpoint="dataproc.googleapis.com" ), ) - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_job_controller_host_with_port(): @@ -1694,192 +1836,124 @@ def test_job_controller_host_with_port(): api_endpoint="dataproc.googleapis.com:8000" ), ) - assert client._transport._host == "dataproc.googleapis.com:8000" + assert client.transport._host == "dataproc.googleapis.com:8000" def test_job_controller_grpc_transport_channel(): channel = grpc.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.JobControllerGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called + assert transport._ssl_channel_credentials == None def test_job_controller_grpc_asyncio_transport_channel(): channel = aio.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. 
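# test_transport_adc and test_job_controller_base_transport_with_adc above
# verify the credential fallback: when neither credentials nor a credentials
# file is supplied, the transport asks google.auth.default() for Application
# Default Credentials. A minimal sketch of that resolution order, assuming
# only that google-auth is installed:
from google import auth

def resolve_credentials(credentials=None):
    # Explicit credentials win; otherwise fall back to ADC, which is exactly
    # what the tests assert via mock.patch.object(auth, "default").
    if credentials is not None:
        return credentials
    creds, _project = auth.default()
    return creds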
transport = transports.JobControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_job_controller_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.JobControllerGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_job_controller_grpc_asyncio_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.JobControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.JobControllerGrpcTransport, + transports.JobControllerGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_job_controller_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. 
- mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - mock_cred = mock.Mock() - transport = transports.JobControllerGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel +def test_job_controller_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.JobControllerGrpcTransport, + transports.JobControllerGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_job_controller_grpc_asyncio_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
+def test_job_controller_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - mock_cred = mock.Mock() - transport = transports.JobControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel def test_job_controller_grpc_lro_client(): client = JobControllerClient( credentials=credentials.AnonymousCredentials(), transport="grpc", ) - transport = client._transport + transport = client.transport # Ensure that we have a api-core operations client. assert isinstance(transport.operations_client, operations_v1.OperationsClient,) @@ -1892,10 +1966,132 @@ def test_job_controller_grpc_lro_async_client(): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - transport = client._client._transport + transport = client.transport # Ensure that we have a api-core operations client. assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = JobControllerClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = JobControllerClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = JobControllerClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = JobControllerClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = JobControllerClient.common_folder_path(**expected) + + # Check that the path construction is reversible. 
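# The common_*_path tests around this point all assert one contract: each
# builder formats IDs into a fixed template, and the matching
# parse_common_*_path inverts it. A minimal sketch of that round trip over a
# generic template, independent of the generated helpers:
import re

def build_path(template, **ids):
    # build_path("folders/{folder}", folder="octopus") -> "folders/octopus"
    return template.format(**ids)

def parse_path(template, path):
    # Turn each "{name}" placeholder into a named group, then match the path.
    pattern = re.sub(r"\{(\w+)\}", r"(?P<\1>[^/]+)", template)
    match = re.fullmatch(pattern, path)
    return match.groupdict() if match else {}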
+ actual = JobControllerClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = JobControllerClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = JobControllerClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = JobControllerClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = JobControllerClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = JobControllerClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = JobControllerClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = JobControllerClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = JobControllerClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = JobControllerClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.JobControllerTransport, "_prep_wrapped_messages" + ) as prep: + client = JobControllerClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.JobControllerTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = JobControllerClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py b/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py index c88ec4ab..be036496 100644 --- a/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py +++ b/tests/unit/gapic/dataproc_v1/test_workflow_template_service.py @@ -31,7 +31,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async -from google.api_core import operation_async +from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError @@ -44,9 +44,7 @@ from google.cloud.dataproc_v1.services.workflow_template_service import pagers from google.cloud.dataproc_v1.services.workflow_template_service import transports from google.cloud.dataproc_v1.types import clusters -from google.cloud.dataproc_v1.types import clusters as gcd_clusters from google.cloud.dataproc_v1.types import jobs -from google.cloud.dataproc_v1.types import jobs as gcd_jobs from 
google.cloud.dataproc_v1.types import shared from google.cloud.dataproc_v1.types import workflow_templates from google.longrunning import operations_pb2 @@ -110,12 +108,12 @@ def test_workflow_template_service_client_from_service_account_file(client_class ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_workflow_template_service_client_get_transport_class(): @@ -175,14 +173,14 @@ def test_workflow_template_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -191,14 +189,14 @@ def test_workflow_template_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -207,90 +205,185 @@ def test_workflow_template_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
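# The "never"/"always"/"Unsupported" cases here fix the accepted vocabulary of
# the two environment variables: GOOGLE_API_USE_MTLS_ENDPOINT must be "never",
# "auto", or "always" (anything else raises MutualTLSChannelError), and
# GOOGLE_API_USE_CLIENT_CERTIFICATE must be "true" or "false" (anything else
# raises ValueError). A minimal sketch of that validation, assuming a
# hypothetical validate_mtls_env helper:
import os

def validate_mtls_env():
    use_mtls = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
    if use_mtls not in ("never", "auto", "always"):
        # The real clients raise MutualTLSChannelError for this case.
        raise ValueError("GOOGLE_API_USE_MTLS_ENDPOINT must be never/auto/always")
    use_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
    if use_cert not in ("true", "false"):
        raise ValueError("GOOGLE_API_USE_CLIENT_CERTIFICATE must be true or false")
    return use_mtls, use_cert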
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceGrpcTransport, + "grpc", + "true", + ), + ( + WorkflowTemplateServiceAsyncClient, + transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + WorkflowTemplateServiceClient, + transports.WorkflowTemplateServiceGrpcTransport, + "grpc", + "false", + ), + ( + WorkflowTemplateServiceAsyncClient, + transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + WorkflowTemplateServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(WorkflowTemplateServiceClient), +) +@mock.patch.object( + WorkflowTemplateServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(WorkflowTemplateServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_workflow_template_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and default_client_cert_source is provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): - with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds ): patched.return_value = None - client = client_class() + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, + host=expected_host, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=expected_ssl_channel_creds, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", but client_cert_source and default_client_cert_source are None. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None ): - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has - # unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id="octopus", - ) + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) @pytest.mark.parametrize( @@ -321,9 +414,9 @@ def test_workflow_template_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -355,9 +448,9 @@ def test_workflow_template_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -374,9 +467,9 @@ def test_workflow_template_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +487,7 @@ def test_create_workflow_template( # Mock the actual 
call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_workflow_template), "__call__" + type(client.transport.create_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate( @@ -410,6 +503,7 @@ def test_create_workflow_template( assert args[0] == workflow_templates.CreateWorkflowTemplateRequest() # Establish that the response is the type that we expect. + assert isinstance(response, workflow_templates.WorkflowTemplate) assert response.id == "id_value" @@ -424,18 +518,21 @@ def test_create_workflow_template_from_dict(): @pytest.mark.asyncio -async def test_create_workflow_template_async(transport: str = "grpc_asyncio"): +async def test_create_workflow_template_async( + transport: str = "grpc_asyncio", + request_type=workflow_templates.CreateWorkflowTemplateRequest, +): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = workflow_templates.CreateWorkflowTemplateRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_workflow_template), "__call__" + type(client.transport.create_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -450,7 +547,7 @@ async def test_create_workflow_template_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.CreateWorkflowTemplateRequest() # Establish that the response is the type that we expect. assert isinstance(response, workflow_templates.WorkflowTemplate) @@ -462,6 +559,11 @@ async def test_create_workflow_template_async(transport: str = "grpc_asyncio"): assert response.version == 774 +@pytest.mark.asyncio +async def test_create_workflow_template_async_from_dict(): + await test_create_workflow_template_async(request_type=dict) + + def test_create_workflow_template_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -474,7 +576,7 @@ def test_create_workflow_template_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_workflow_template), "__call__" + type(client.transport.create_workflow_template), "__call__" ) as call: call.return_value = workflow_templates.WorkflowTemplate() @@ -503,7 +605,7 @@ async def test_create_workflow_template_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_workflow_template), "__call__" + type(client.transport.create_workflow_template), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( workflow_templates.WorkflowTemplate() @@ -528,7 +630,7 @@ def test_create_workflow_template_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_workflow_template), "__call__" + type(client.transport.create_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. 
call.return_value = workflow_templates.WorkflowTemplate() @@ -573,7 +675,7 @@ async def test_create_workflow_template_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_workflow_template), "__call__" + type(client.transport.create_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate() @@ -627,7 +729,7 @@ def test_get_workflow_template( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_workflow_template), "__call__" + type(client.transport.get_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate( @@ -643,6 +745,7 @@ def test_get_workflow_template( assert args[0] == workflow_templates.GetWorkflowTemplateRequest() # Establish that the response is the type that we expect. + assert isinstance(response, workflow_templates.WorkflowTemplate) assert response.id == "id_value" @@ -657,18 +760,21 @@ def test_get_workflow_template_from_dict(): @pytest.mark.asyncio -async def test_get_workflow_template_async(transport: str = "grpc_asyncio"): +async def test_get_workflow_template_async( + transport: str = "grpc_asyncio", + request_type=workflow_templates.GetWorkflowTemplateRequest, +): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = workflow_templates.GetWorkflowTemplateRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_workflow_template), "__call__" + type(client.transport.get_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -683,7 +789,7 @@ async def test_get_workflow_template_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.GetWorkflowTemplateRequest() # Establish that the response is the type that we expect. assert isinstance(response, workflow_templates.WorkflowTemplate) @@ -695,6 +801,11 @@ async def test_get_workflow_template_async(transport: str = "grpc_asyncio"): assert response.version == 774 +@pytest.mark.asyncio +async def test_get_workflow_template_async_from_dict(): + await test_get_workflow_template_async(request_type=dict) + + def test_get_workflow_template_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -707,7 +818,7 @@ def test_get_workflow_template_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_workflow_template), "__call__" + type(client.transport.get_workflow_template), "__call__" ) as call: call.return_value = workflow_templates.WorkflowTemplate() @@ -736,7 +847,7 @@ async def test_get_workflow_template_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._client._transport.get_workflow_template), "__call__" + type(client.transport.get_workflow_template), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( workflow_templates.WorkflowTemplate() @@ -761,7 +872,7 @@ def test_get_workflow_template_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_workflow_template), "__call__" + type(client.transport.get_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate() @@ -799,7 +910,7 @@ async def test_get_workflow_template_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_workflow_template), "__call__" + type(client.transport.get_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate() @@ -847,7 +958,7 @@ def test_instantiate_workflow_template( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.instantiate_workflow_template), "__call__" + type(client.transport.instantiate_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") @@ -869,18 +980,21 @@ def test_instantiate_workflow_template_from_dict(): @pytest.mark.asyncio -async def test_instantiate_workflow_template_async(transport: str = "grpc_asyncio"): +async def test_instantiate_workflow_template_async( + transport: str = "grpc_asyncio", + request_type=workflow_templates.InstantiateWorkflowTemplateRequest, +): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = workflow_templates.InstantiateWorkflowTemplateRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.instantiate_workflow_template), "__call__" + type(client.transport.instantiate_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -893,12 +1007,17 @@ async def test_instantiate_workflow_template_async(transport: str = "grpc_asynci assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.InstantiateWorkflowTemplateRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_instantiate_workflow_template_async_from_dict(): + await test_instantiate_workflow_template_async(request_type=dict) + + def test_instantiate_workflow_template_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -911,7 +1030,7 @@ def test_instantiate_workflow_template_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._transport.instantiate_workflow_template), "__call__" + type(client.transport.instantiate_workflow_template), "__call__" ) as call: call.return_value = operations_pb2.Operation(name="operations/op") @@ -940,7 +1059,7 @@ async def test_instantiate_workflow_template_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.instantiate_workflow_template), "__call__" + type(client.transport.instantiate_workflow_template), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/op") @@ -965,7 +1084,7 @@ def test_instantiate_workflow_template_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.instantiate_workflow_template), "__call__" + type(client.transport.instantiate_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -1009,7 +1128,7 @@ async def test_instantiate_workflow_template_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.instantiate_workflow_template), "__call__" + type(client.transport.instantiate_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -1063,7 +1182,7 @@ def test_instantiate_inline_workflow_template( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.instantiate_inline_workflow_template), "__call__" + type(client.transport.instantiate_inline_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") @@ -1087,6 +1206,7 @@ def test_instantiate_inline_workflow_template_from_dict(): @pytest.mark.asyncio async def test_instantiate_inline_workflow_template_async( transport: str = "grpc_asyncio", + request_type=workflow_templates.InstantiateInlineWorkflowTemplateRequest, ): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1094,11 +1214,11 @@ async def test_instantiate_inline_workflow_template_async( # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = workflow_templates.InstantiateInlineWorkflowTemplateRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.instantiate_inline_workflow_template), "__call__" + type(client.transport.instantiate_inline_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -1111,12 +1231,17 @@ async def test_instantiate_inline_workflow_template_async( assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.InstantiateInlineWorkflowTemplateRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_instantiate_inline_workflow_template_async_from_dict(): + await test_instantiate_inline_workflow_template_async(request_type=dict) + + def test_instantiate_inline_workflow_template_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1129,7 +1254,7 @@ def test_instantiate_inline_workflow_template_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.instantiate_inline_workflow_template), "__call__" + type(client.transport.instantiate_inline_workflow_template), "__call__" ) as call: call.return_value = operations_pb2.Operation(name="operations/op") @@ -1158,7 +1283,7 @@ async def test_instantiate_inline_workflow_template_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.instantiate_inline_workflow_template), "__call__" + type(client.transport.instantiate_inline_workflow_template), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/op") @@ -1183,7 +1308,7 @@ def test_instantiate_inline_workflow_template_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.instantiate_inline_workflow_template), "__call__" + type(client.transport.instantiate_inline_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -1228,7 +1353,7 @@ async def test_instantiate_inline_workflow_template_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.instantiate_inline_workflow_template), "__call__" + type(client.transport.instantiate_inline_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -1283,7 +1408,7 @@ def test_update_workflow_template( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_workflow_template), "__call__" + type(client.transport.update_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate( @@ -1299,6 +1424,7 @@ def test_update_workflow_template( assert args[0] == workflow_templates.UpdateWorkflowTemplateRequest() # Establish that the response is the type that we expect. + assert isinstance(response, workflow_templates.WorkflowTemplate) assert response.id == "id_value" @@ -1313,18 +1439,21 @@ def test_update_workflow_template_from_dict(): @pytest.mark.asyncio -async def test_update_workflow_template_async(transport: str = "grpc_asyncio"): +async def test_update_workflow_template_async( + transport: str = "grpc_asyncio", + request_type=workflow_templates.UpdateWorkflowTemplateRequest, +): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. 
- request = workflow_templates.UpdateWorkflowTemplateRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_workflow_template), "__call__" + type(client.transport.update_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -1339,7 +1468,7 @@ async def test_update_workflow_template_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.UpdateWorkflowTemplateRequest() # Establish that the response is the type that we expect. assert isinstance(response, workflow_templates.WorkflowTemplate) @@ -1351,6 +1480,11 @@ async def test_update_workflow_template_async(transport: str = "grpc_asyncio"): assert response.version == 774 +@pytest.mark.asyncio +async def test_update_workflow_template_async_from_dict(): + await test_update_workflow_template_async(request_type=dict) + + def test_update_workflow_template_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1363,7 +1497,7 @@ def test_update_workflow_template_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_workflow_template), "__call__" + type(client.transport.update_workflow_template), "__call__" ) as call: call.return_value = workflow_templates.WorkflowTemplate() @@ -1394,7 +1528,7 @@ async def test_update_workflow_template_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_workflow_template), "__call__" + type(client.transport.update_workflow_template), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( workflow_templates.WorkflowTemplate() @@ -1421,7 +1555,7 @@ def test_update_workflow_template_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_workflow_template), "__call__" + type(client.transport.update_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate() @@ -1462,7 +1596,7 @@ async def test_update_workflow_template_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_workflow_template), "__call__" + type(client.transport.update_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.WorkflowTemplate() @@ -1513,7 +1647,7 @@ def test_list_workflow_templates( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.ListWorkflowTemplatesResponse( @@ -1529,6 +1663,7 @@ def test_list_workflow_templates( assert args[0] == workflow_templates.ListWorkflowTemplatesRequest() # Establish that the response is the type that we expect. 
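A pattern worth noting in the hunks above: each async test now takes a `request_type` parameter defaulting to the proto request class, and a thin `*_async_from_dict` test re-invokes it with `request_type=dict`. Because the client coerces a dict into the proto message before calling the stub, the assertion is also changed to compare against a freshly constructed proto request instead of the original `request` object. A generic, self-contained sketch of the pattern (the names `EchoRequest` and `coerce` are illustrative, not from this patch):

import pytest


class EchoRequest:
    """Stand-in for a proto request class that also accepts a mapping."""

    def __init__(self, mapping=None):
        self._fields = dict(mapping or {})

    def __eq__(self, other):
        return isinstance(other, EchoRequest) and self._fields == other._fields


def coerce(request):
    # Mirrors the GAPIC layer: dict-shaped requests become proto messages.
    return request if isinstance(request, EchoRequest) else EchoRequest(request)


@pytest.mark.asyncio
async def test_echo_async(request_type=EchoRequest):
    request = coerce(request_type())
    # Whatever shape came in, the stub always sees the proto type.
    assert request == EchoRequest()


@pytest.mark.asyncio
async def test_echo_async_from_dict():
    # Reuse the same coroutine with a dict request, mirroring *_async_from_dict.
    await test_echo_async(request_type=dict)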
+ assert isinstance(response, pagers.ListWorkflowTemplatesPager) assert response.next_page_token == "next_page_token_value" @@ -1539,18 +1674,21 @@ def test_list_workflow_templates_from_dict(): @pytest.mark.asyncio -async def test_list_workflow_templates_async(transport: str = "grpc_asyncio"): +async def test_list_workflow_templates_async( + transport: str = "grpc_asyncio", + request_type=workflow_templates.ListWorkflowTemplatesRequest, +): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = workflow_templates.ListWorkflowTemplatesRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -1565,7 +1703,7 @@ async def test_list_workflow_templates_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.ListWorkflowTemplatesRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListWorkflowTemplatesAsyncPager) @@ -1573,6 +1711,11 @@ async def test_list_workflow_templates_async(transport: str = "grpc_asyncio"): assert response.next_page_token == "next_page_token_value" +@pytest.mark.asyncio +async def test_list_workflow_templates_async_from_dict(): + await test_list_workflow_templates_async(request_type=dict) + + def test_list_workflow_templates_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1585,7 +1728,7 @@ def test_list_workflow_templates_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: call.return_value = workflow_templates.ListWorkflowTemplatesResponse() @@ -1614,7 +1757,7 @@ async def test_list_workflow_templates_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( workflow_templates.ListWorkflowTemplatesResponse() @@ -1639,7 +1782,7 @@ def test_list_workflow_templates_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = workflow_templates.ListWorkflowTemplatesResponse() @@ -1677,7 +1820,7 @@ async def test_list_workflow_templates_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: # Designate an appropriate return value for the call. 
call.return_value = workflow_templates.ListWorkflowTemplatesResponse() @@ -1718,7 +1861,7 @@ def test_list_workflow_templates_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1766,7 +1909,7 @@ def test_list_workflow_templates_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_workflow_templates), "__call__" + type(client.transport.list_workflow_templates), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1794,8 +1937,8 @@ def test_list_workflow_templates_pages(): RuntimeError, ) pages = list(client.list_workflow_templates(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio @@ -1806,7 +1949,7 @@ async def test_list_workflow_templates_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_workflow_templates), + type(client.transport.list_workflow_templates), "__call__", new_callable=mock.AsyncMock, ) as call: @@ -1855,7 +1998,7 @@ async def test_list_workflow_templates_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_workflow_templates), + type(client.transport.list_workflow_templates), "__call__", new_callable=mock.AsyncMock, ) as call: @@ -1885,10 +2028,10 @@ async def test_list_workflow_templates_async_pages(): RuntimeError, ) pages = [] - async for page in (await client.list_workflow_templates(request={})).pages: - pages.append(page) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + async for page_ in (await client.list_workflow_templates(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token def test_delete_workflow_template( @@ -1905,7 +2048,7 @@ def test_delete_workflow_template( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_workflow_template), "__call__" + type(client.transport.delete_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1927,18 +2070,21 @@ def test_delete_workflow_template_from_dict(): @pytest.mark.asyncio -async def test_delete_workflow_template_async(transport: str = "grpc_asyncio"): +async def test_delete_workflow_template_async( + transport: str = "grpc_asyncio", + request_type=workflow_templates.DeleteWorkflowTemplateRequest, +): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = workflow_templates.DeleteWorkflowTemplateRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._client._transport.delete_workflow_template), "__call__" + type(client.transport.delete_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1949,12 +2095,17 @@ async def test_delete_workflow_template_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == workflow_templates.DeleteWorkflowTemplateRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_delete_workflow_template_async_from_dict(): + await test_delete_workflow_template_async(request_type=dict) + + def test_delete_workflow_template_field_headers(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1967,7 +2118,7 @@ def test_delete_workflow_template_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_workflow_template), "__call__" + type(client.transport.delete_workflow_template), "__call__" ) as call: call.return_value = None @@ -1996,7 +2147,7 @@ async def test_delete_workflow_template_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_workflow_template), "__call__" + type(client.transport.delete_workflow_template), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -2019,7 +2170,7 @@ def test_delete_workflow_template_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_workflow_template), "__call__" + type(client.transport.delete_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2057,7 +2208,7 @@ async def test_delete_workflow_template_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_workflow_template), "__call__" + type(client.transport.delete_workflow_template), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2125,7 +2276,7 @@ def test_transport_instance(): credentials=credentials.AnonymousCredentials(), ) client = WorkflowTemplateServiceClient(transport=transport) - assert client._transport is transport + assert client.transport is transport def test_transport_get_channel(): @@ -2143,13 +2294,28 @@ def test_transport_get_channel(): assert channel +@pytest.mark.parametrize( + "transport_class", + [ + transports.WorkflowTemplateServiceGrpcTransport, + transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), ) assert isinstance( - client._transport, transports.WorkflowTemplateServiceGrpcTransport, + client.transport, transports.WorkflowTemplateServiceGrpcTransport, ) @@ -2212,6 +2378,17 @@ def test_workflow_template_service_base_transport_with_credentials_file(): ) +def test_workflow_template_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.dataproc_v1.services.workflow_template_service.transports.WorkflowTemplateServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.WorkflowTemplateServiceTransport() + adc.assert_called_once() + + def test_workflow_template_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(auth, "default") as adc: @@ -2244,7 +2421,7 @@ def test_workflow_template_service_host_no_port(): api_endpoint="dataproc.googleapis.com" ), ) - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_workflow_template_service_host_with_port(): @@ -2254,192 +2431,126 @@ def test_workflow_template_service_host_with_port(): api_endpoint="dataproc.googleapis.com:8000" ), ) - assert client._transport._host == "dataproc.googleapis.com:8000" + assert client.transport._host == "dataproc.googleapis.com:8000" def test_workflow_template_service_grpc_transport_channel(): channel = grpc.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.WorkflowTemplateServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called + assert transport._ssl_channel_credentials == None def test_workflow_template_service_grpc_asyncio_transport_channel(): channel = aio.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_workflow_template_service_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. 
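Throughout both test modules, the private `client._transport` (and the async client's `client._client._transport`) is replaced by the `client.transport` property that this change exposes publicly; the same property now backs the host assertions above. A minimal sketch of the exposed property in the mocking style these tests use, assuming the google-cloud-dataproc import paths referenced elsewhere in this patch:

from unittest import mock

from google.auth import credentials
from google.cloud.dataproc_v1.services.workflow_template_service import (
    WorkflowTemplateServiceClient,
)
from google.cloud.dataproc_v1.types import workflow_templates

client = WorkflowTemplateServiceClient(credentials=credentials.AnonymousCredentials())

# The transport (and its per-method gRPC stubs) is now reachable without
# touching private attributes.
with mock.patch.object(
    type(client.transport.get_workflow_template), "__call__"
) as call:
    call.return_value = workflow_templates.WorkflowTemplate(id="id_value")
    response = client.get_workflow_template(
        request=workflow_templates.GetWorkflowTemplateRequest()
    )
    assert response.id == "id_value"
    assert client.transport._host == "dataproc.googleapis.com:443"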
- mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.WorkflowTemplateServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_workflow_template_service_grpc_asyncio_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.WorkflowTemplateServiceGrpcTransport, + transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_workflow_template_service_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint +def test_workflow_template_service_transport_channel_mtls_with_client_cert_source( + transport_class, ): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
- mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - mock_cred = mock.Mock() - transport = transports.WorkflowTemplateServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.WorkflowTemplateServiceGrpcTransport, + transports.WorkflowTemplateServiceGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_workflow_template_service_grpc_asyncio_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
+def test_workflow_template_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - mock_cred = mock.Mock() - transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel def test_workflow_template_service_grpc_lro_client(): client = WorkflowTemplateServiceClient( credentials=credentials.AnonymousCredentials(), transport="grpc", ) - transport = client._transport + transport = client.transport # Ensure that we have a api-core operations client. assert isinstance(transport.operations_client, operations_v1.OperationsClient,) @@ -2452,7 +2563,7 @@ def test_workflow_template_service_grpc_lro_async_client(): client = WorkflowTemplateServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - transport = client._client._transport + transport = client.transport # Ensure that we have a api-core operations client. assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) @@ -2486,3 +2597,125 @@ def test_parse_workflow_template_path(): # Check that the path construction is reversible. actual = WorkflowTemplateServiceClient.parse_workflow_template_path(path) assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = WorkflowTemplateServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = WorkflowTemplateServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. 
+ actual = WorkflowTemplateServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = WorkflowTemplateServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = WorkflowTemplateServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = WorkflowTemplateServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = WorkflowTemplateServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = WorkflowTemplateServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = WorkflowTemplateServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = WorkflowTemplateServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = WorkflowTemplateServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = WorkflowTemplateServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = WorkflowTemplateServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = WorkflowTemplateServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
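These `common_*_path` helpers (billing account, folder, organization, project, location) are the "common resource paths" named in the change title: each is a pure string formatter on the client class with a matching `parse_common_*_path` inverse that returns the captured segments. A quick usage sketch with illustrative values ("my-project", "us-central1" are examples, not from this patch):

from google.cloud.dataproc_v1.services.workflow_template_service import (
    WorkflowTemplateServiceClient,
)

path = WorkflowTemplateServiceClient.common_location_path("my-project", "us-central1")
assert path == "projects/my-project/locations/us-central1"

# parse_* inverts the formatter, returning the captured segments as a dict.
assert WorkflowTemplateServiceClient.parse_common_location_path(path) == {
    "project": "my-project",
    "location": "us-central1",
}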
+ actual = WorkflowTemplateServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.WorkflowTemplateServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = WorkflowTemplateServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.WorkflowTemplateServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = WorkflowTemplateServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/dataproc_v1beta2/test_autoscaling_policy_service.py b/tests/unit/gapic/dataproc_v1beta2/test_autoscaling_policy_service.py index cfe27251..239f2993 100644 --- a/tests/unit/gapic/dataproc_v1beta2/test_autoscaling_policy_service.py +++ b/tests/unit/gapic/dataproc_v1beta2/test_autoscaling_policy_service.py @@ -101,12 +101,12 @@ def test_autoscaling_policy_service_client_from_service_account_file(client_clas ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_autoscaling_policy_service_client_get_transport_class(): @@ -170,14 +170,14 @@ def test_autoscaling_policy_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -186,14 +186,14 @@ def test_autoscaling_policy_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -202,90 +202,185 @@ def test_autoscaling_policy_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceGrpcTransport, + "grpc", + "true", + ), + ( + AutoscalingPolicyServiceAsyncClient, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + AutoscalingPolicyServiceClient, + transports.AutoscalingPolicyServiceGrpcTransport, + "grpc", + "false", + ), + ( + AutoscalingPolicyServiceAsyncClient, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + AutoscalingPolicyServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AutoscalingPolicyServiceClient), +) +@mock.patch.object( + AutoscalingPolicyServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(AutoscalingPolicyServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_autoscaling_policy_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and default_client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): - with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds ): patched.return_value = None - client = client_class() + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, + host=expected_host, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=expected_ssl_channel_creds, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", but client_cert_source and default_client_cert_source are None. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None ): - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has - # unsupported value. 
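The rewritten client-options tests encode the new two-variable contract: `GOOGLE_API_USE_MTLS_ENDPOINT` (`never`/`auto`/`always`, replacing the old `GOOGLE_API_USE_MTLS`) selects which endpoint the client targets, while the separate `GOOGLE_API_USE_CLIENT_CERTIFICATE` (`true`/`false`) controls whether a client certificate is sent; per the hunks above, unsupported values raise `MutualTLSChannelError` and `ValueError` respectively. A minimal sketch of forcing the mTLS endpoint, assuming the v1beta2 client import path follows the same layout as the v1 paths in this patch:

import os
from unittest import mock

from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.autoscaling_policy_service import (
    AutoscalingPolicyServiceClient,
)

# "always" switches the client to its default mTLS endpoint even when no
# client certificate is configured.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
    client = AutoscalingPolicyServiceClient(
        credentials=credentials.AnonymousCredentials()
    )
    assert client.transport._host == (
        AutoscalingPolicyServiceClient.DEFAULT_MTLS_ENDPOINT + ":443"
    )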
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id="octopus", - ) + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) @pytest.mark.parametrize( @@ -316,9 +411,9 @@ def test_autoscaling_policy_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -350,9 +445,9 @@ def test_autoscaling_policy_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -369,9 +464,9 @@ def test_autoscaling_policy_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -389,7 +484,7 @@ def test_create_autoscaling_policy( # Mock the 
actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy( @@ -411,6 +506,7 @@ def test_create_autoscaling_policy( assert args[0] == autoscaling_policies.CreateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) assert response.id == "id_value" @@ -423,18 +519,21 @@ def test_create_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_create_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_create_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.CreateAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.CreateAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -447,7 +546,7 @@ async def test_create_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.CreateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. assert isinstance(response, autoscaling_policies.AutoscalingPolicy) @@ -457,6 +556,11 @@ async def test_create_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert response.name == "name_value" +@pytest.mark.asyncio +async def test_create_autoscaling_policy_async_from_dict(): + await test_create_autoscaling_policy_async(request_type=dict) + + def test_create_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -469,7 +573,7 @@ def test_create_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -498,7 +602,7 @@ async def test_create_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.AutoscalingPolicy() @@ -523,7 +627,7 @@ def test_create_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -568,7 +672,7 @@ async def test_create_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.create_autoscaling_policy), "__call__" + type(client.transport.create_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -623,7 +727,7 @@ def test_update_autoscaling_policy( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy( @@ -645,6 +749,7 @@ def test_update_autoscaling_policy( assert args[0] == autoscaling_policies.UpdateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) assert response.id == "id_value" @@ -657,18 +762,21 @@ def test_update_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_update_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_update_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.UpdateAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.UpdateAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -681,7 +789,7 @@ async def test_update_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.UpdateAutoscalingPolicyRequest() # Establish that the response is the type that we expect. assert isinstance(response, autoscaling_policies.AutoscalingPolicy) @@ -691,6 +799,11 @@ async def test_update_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert response.name == "name_value" +@pytest.mark.asyncio +async def test_update_autoscaling_policy_async_from_dict(): + await test_update_autoscaling_policy_async(request_type=dict) + + def test_update_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -703,7 +816,7 @@ def test_update_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -732,7 +845,7 @@ async def test_update_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.AutoscalingPolicy() @@ -757,7 +870,7 @@ def test_update_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -798,7 +911,7 @@ async def test_update_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.update_autoscaling_policy), "__call__" + type(client.transport.update_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -849,7 +962,7 @@ def test_get_autoscaling_policy( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy( @@ -871,6 +984,7 @@ def test_get_autoscaling_policy( assert args[0] == autoscaling_policies.GetAutoscalingPolicyRequest() # Establish that the response is the type that we expect. + assert isinstance(response, autoscaling_policies.AutoscalingPolicy) assert response.id == "id_value" @@ -883,18 +997,21 @@ def test_get_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_get_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_get_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.GetAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.GetAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -907,7 +1024,7 @@ async def test_get_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.GetAutoscalingPolicyRequest() # Establish that the response is the type that we expect. 
assert isinstance(response, autoscaling_policies.AutoscalingPolicy) @@ -917,6 +1034,11 @@ async def test_get_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert response.name == "name_value" +@pytest.mark.asyncio +async def test_get_autoscaling_policy_async_from_dict(): + await test_get_autoscaling_policy_async(request_type=dict) + + def test_get_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -929,7 +1051,7 @@ def test_get_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -958,7 +1080,7 @@ async def test_get_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.AutoscalingPolicy() @@ -983,7 +1105,7 @@ def test_get_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -1021,7 +1143,7 @@ async def test_get_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.get_autoscaling_policy), "__call__" + type(client.transport.get_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.AutoscalingPolicy() @@ -1069,7 +1191,7 @@ def test_list_autoscaling_policies( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse( @@ -1085,6 +1207,7 @@ def test_list_autoscaling_policies( assert args[0] == autoscaling_policies.ListAutoscalingPoliciesRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAutoscalingPoliciesPager) assert response.next_page_token == "next_page_token_value" @@ -1095,18 +1218,21 @@ def test_list_autoscaling_policies_from_dict(): @pytest.mark.asyncio -async def test_list_autoscaling_policies_async(transport: str = "grpc_asyncio"): +async def test_list_autoscaling_policies_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.ListAutoscalingPoliciesRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. 
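Each async test now accepts a `request_type` argument defaulting to the proto request class, and a small `*_async_from_dict` wrapper re-runs the same body with a plain `dict`. Accordingly, the assertion compares `args[0]` against a freshly built proto instead of the local `request`, which would not compare equal when a dict was sent. A rough sketch of the pattern, with hypothetical names:

```python
import asyncio


class FakeRequest:
    """Hypothetical stand-in for a proto request class."""


async def exercise_rpc(request_type=FakeRequest):
    # The same body runs with the proto class or a plain dict.
    request = request_type()
    assert isinstance(request, request_type)


async def exercise_rpc_from_dict():
    await exercise_rpc(request_type=dict)


asyncio.run(exercise_rpc())
asyncio.run(exercise_rpc_from_dict())
```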
- request = autoscaling_policies.ListAutoscalingPoliciesRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -1121,7 +1247,7 @@ async def test_list_autoscaling_policies_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.ListAutoscalingPoliciesRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListAutoscalingPoliciesAsyncPager) @@ -1129,6 +1255,11 @@ async def test_list_autoscaling_policies_async(transport: str = "grpc_asyncio"): assert response.next_page_token == "next_page_token_value" +@pytest.mark.asyncio +async def test_list_autoscaling_policies_async_from_dict(): + await test_list_autoscaling_policies_async(request_type=dict) + + def test_list_autoscaling_policies_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1141,7 +1272,7 @@ def test_list_autoscaling_policies_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1170,7 +1301,7 @@ async def test_list_autoscaling_policies_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1195,7 +1326,7 @@ def test_list_autoscaling_policies_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1234,7 +1365,7 @@ async def test_list_autoscaling_policies_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = autoscaling_policies.ListAutoscalingPoliciesResponse() @@ -1276,7 +1407,7 @@ def test_list_autoscaling_policies_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1326,7 +1457,7 @@ def test_list_autoscaling_policies_pages(): # Mock the actual call within the gRPC stub, and fake the request. 
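The async variants wrap the canned response in `grpc_helpers_async.FakeUnaryUnaryCall`, an awaitable that resolves to the response, so `await client.<rpc>(...)` works even though the stub is a synchronous mock. A rough stand-in capturing just the awaitable behavior (an assumption about what the helper provides, not its actual implementation):

```python
import asyncio


class FakeCall:
    """Rough stand-in for grpc_helpers_async.FakeUnaryUnaryCall."""

    def __init__(self, response):
        self._response = response

    def __await__(self):
        async def _resolve():
            return self._response

        return _resolve().__await__()


async def main():
    assert await FakeCall("next_page_token_value") == "next_page_token_value"


asyncio.run(main())
```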
with mock.patch.object( - type(client._transport.list_autoscaling_policies), "__call__" + type(client.transport.list_autoscaling_policies), "__call__" ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1354,8 +1485,8 @@ def test_list_autoscaling_policies_pages(): RuntimeError, ) pages = list(client.list_autoscaling_policies(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio @@ -1366,7 +1497,7 @@ async def test_list_autoscaling_policies_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), + type(client.transport.list_autoscaling_policies), "__call__", new_callable=mock.AsyncMock, ) as call: @@ -1415,7 +1546,7 @@ async def test_list_autoscaling_policies_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_autoscaling_policies), + type(client.transport.list_autoscaling_policies), "__call__", new_callable=mock.AsyncMock, ) as call: @@ -1445,10 +1576,10 @@ async def test_list_autoscaling_policies_async_pages(): RuntimeError, ) pages = [] - async for page in (await client.list_autoscaling_policies(request={})).pages: - pages.append(page) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + async for page_ in (await client.list_autoscaling_policies(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token def test_delete_autoscaling_policy( @@ -1465,7 +1596,7 @@ def test_delete_autoscaling_policy( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1487,18 +1618,21 @@ def test_delete_autoscaling_policy_from_dict(): @pytest.mark.asyncio -async def test_delete_autoscaling_policy_async(transport: str = "grpc_asyncio"): +async def test_delete_autoscaling_policy_async( + transport: str = "grpc_asyncio", + request_type=autoscaling_policies.DeleteAutoscalingPolicyRequest, +): client = AutoscalingPolicyServiceAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = autoscaling_policies.DeleteAutoscalingPolicyRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. 
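The pager loops rename the loop variable `page` to `page_`, in line with the regenerated templates, while the assertion itself is unchanged: each page's `raw_page` must carry the expected `next_page_token`. A minimal sketch with an assumed page shape:

```python
class FakePage:
    """Assumed page shape: raw_page exposes next_page_token."""

    def __init__(self, token):
        self.raw_page = self
        self.next_page_token = token


pages = [FakePage("abc"), FakePage("def"), FakePage("ghi"), FakePage("")]
for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
    assert page_.raw_page.next_page_token == token
```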
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1509,12 +1643,17 @@ async def test_delete_autoscaling_policy_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == autoscaling_policies.DeleteAutoscalingPolicyRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_delete_autoscaling_policy_async_from_dict(): + await test_delete_autoscaling_policy_async(request_type=dict) + + def test_delete_autoscaling_policy_field_headers(): client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1527,7 +1666,7 @@ def test_delete_autoscaling_policy_field_headers(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: call.return_value = None @@ -1556,7 +1695,7 @@ async def test_delete_autoscaling_policy_field_headers_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1579,7 +1718,7 @@ def test_delete_autoscaling_policy_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1617,7 +1756,7 @@ async def test_delete_autoscaling_policy_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.delete_autoscaling_policy), "__call__" + type(client.transport.delete_autoscaling_policy), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1685,7 +1824,7 @@ def test_transport_instance(): credentials=credentials.AnonymousCredentials(), ) client = AutoscalingPolicyServiceClient(transport=transport) - assert client._transport is transport + assert client.transport is transport def test_transport_get_channel(): @@ -1703,13 +1842,28 @@ def test_transport_get_channel(): assert channel +@pytest.mark.parametrize( + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
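The new `test_transport_adc` pins down the fallback to Application Default Credentials: constructing a transport with neither `credentials` nor `credentials_file` should call `google.auth.default()` exactly once. A self-contained sketch using a stand-in transport and a local `auth_default` so it runs without real ADC:

```python
from unittest import mock


def auth_default():
    """Stand-in for google.auth.default(), returning (credentials, project)."""
    raise RuntimeError("no ADC in this sketch")


class FakeTransport:
    def __init__(self, credentials=None, credentials_file=None):
        if credentials is None and credentials_file is None:
            # Fall back to ADC, which the test asserts happens once.
            credentials, _ = auth_default()
        self._credentials = credentials


with mock.patch(f"{__name__}.auth_default") as adc:
    adc.return_value = (mock.sentinel.credentials, None)
    FakeTransport()
    adc.assert_called_once()
```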
client = AutoscalingPolicyServiceClient( credentials=credentials.AnonymousCredentials(), ) assert isinstance( - client._transport, transports.AutoscalingPolicyServiceGrpcTransport, + client.transport, transports.AutoscalingPolicyServiceGrpcTransport, ) @@ -1765,6 +1919,17 @@ def test_autoscaling_policy_service_base_transport_with_credentials_file(): ) +def test_autoscaling_policy_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.dataproc_v1beta2.services.autoscaling_policy_service.transports.AutoscalingPolicyServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.AutoscalingPolicyServiceTransport() + adc.assert_called_once() + + def test_autoscaling_policy_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(auth, "default") as adc: @@ -1797,7 +1962,7 @@ def test_autoscaling_policy_service_host_no_port(): api_endpoint="dataproc.googleapis.com" ), ) - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_autoscaling_policy_service_host_with_port(): @@ -1807,185 +1972,119 @@ def test_autoscaling_policy_service_host_with_port(): api_endpoint="dataproc.googleapis.com:8000" ), ) - assert client._transport._host == "dataproc.googleapis.com:8000" + assert client.transport._host == "dataproc.googleapis.com:8000" def test_autoscaling_policy_service_grpc_transport_channel(): channel = grpc.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.AutoscalingPolicyServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called + assert transport._ssl_channel_credentials == None def test_autoscaling_policy_service_grpc_asyncio_transport_channel(): channel = aio.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. 
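The simplified channel tests also assert a new `_ssl_channel_credentials` attribute, which the transports populate when a client certificate source is turned into channel credentials. A sketch of that conversion with `grpc.ssl_channel_credentials`, using the same byte values as the test fixtures:

```python
import grpc


def client_cert_source_callback():
    # Mirrors the test fixture: returns (certificate_chain, private_key).
    return b"cert bytes", b"key bytes"


cert, key = client_cert_source_callback()
ssl_creds = grpc.ssl_channel_credentials(
    certificate_chain=cert, private_key=key
)
# The transport would keep a reference (the tests read
# _ssl_channel_credentials) and pass it to create_channel(...,
# ssl_credentials=ssl_creds).
```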
- mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.AutoscalingPolicyServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_asyncio_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint +def test_autoscaling_policy_service_transport_channel_mtls_with_client_cert_source( + transport_class, ): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
- mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - mock_cred = mock.Mock() - transport = transports.AutoscalingPolicyServiceGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.AutoscalingPolicyServiceGrpcTransport, + transports.AutoscalingPolicyServiceGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_autoscaling_policy_service_grpc_asyncio_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. 
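The replacement test is parametrized over both transport classes and wraps construction in `pytest.warns(DeprecationWarning)`: `api_mtls_endpoint` and `client_cert_source` still work but are deprecated in favor of passing `ssl_channel_credentials` directly. A sketch of the warning check with a stand-in constructor:

```python
import warnings

import pytest


def make_transport(api_mtls_endpoint=None, client_cert_source=None):
    """Stand-in constructor; the deprecated kwargs keep working."""
    if api_mtls_endpoint or client_cert_source:
        warnings.warn(
            "api_mtls_endpoint and client_cert_source are deprecated",
            DeprecationWarning,
        )


def test_deprecated_mtls_kwargs():
    with pytest.warns(DeprecationWarning):
        make_transport(api_mtls_endpoint="mtls.squid.clam.whelk")
```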
+def test_autoscaling_policy_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - mock_cred = mock.Mock() - transport = transports.AutoscalingPolicyServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel def test_autoscaling_policy_path(): @@ -2013,3 +2112,125 @@ def test_parse_autoscaling_policy_path(): # Check that the path construction is reversible. actual = AutoscalingPolicyServiceClient.parse_autoscaling_policy_path(path) assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = AutoscalingPolicyServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = AutoscalingPolicyServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = AutoscalingPolicyServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = AutoscalingPolicyServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = AutoscalingPolicyServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = AutoscalingPolicyServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
+ actual = AutoscalingPolicyServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = AutoscalingPolicyServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = AutoscalingPolicyServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = AutoscalingPolicyServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = AutoscalingPolicyServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = AutoscalingPolicyServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.AutoscalingPolicyServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = AutoscalingPolicyServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.AutoscalingPolicyServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = AutoscalingPolicyServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/dataproc_v1beta2/test_cluster_controller.py b/tests/unit/gapic/dataproc_v1beta2/test_cluster_controller.py index 1dbc1a36..3e6bab80 100644 --- a/tests/unit/gapic/dataproc_v1beta2/test_cluster_controller.py +++ b/tests/unit/gapic/dataproc_v1beta2/test_cluster_controller.py @@ -31,7 +31,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async -from google.api_core import operation_async +from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError @@ -44,7 +44,6 @@ from google.cloud.dataproc_v1beta2.services.cluster_controller import pagers from google.cloud.dataproc_v1beta2.services.cluster_controller import transports from google.cloud.dataproc_v1beta2.types import clusters -from google.cloud.dataproc_v1beta2.types import clusters as gcd_clusters from google.cloud.dataproc_v1beta2.types import operations from google.cloud.dataproc_v1beta2.types import shared from google.longrunning import operations_pb2 @@ -109,12 +108,12 @@ def test_cluster_controller_client_from_service_account_file(client_class): ) as factory: factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds + assert client.transport._credentials == creds client = client_class.from_service_account_json("dummy/file/path.json") - assert 
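The new common resource path tests above all follow the same round-trip shape: a formatter classmethod builds the path and a `parse_*` counterpart recovers the arguments. An illustrative reimplementation of one such pair (the real helpers are generated classmethods on the clients):

```python
import re


def common_organization_path(organization: str) -> str:
    return "organizations/{organization}".format(organization=organization)


def parse_common_organization_path(path: str) -> dict:
    m = re.match(r"^organizations/(?P<organization>.+?)$", path)
    return m.groupdict() if m else {}


expected = {"organization": "abalone"}
path = common_organization_path(**expected)
# Round-trip: parsing recovers the arguments used to build the path.
assert parse_common_organization_path(path) == expected
```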
client._transport._credentials == creds + assert client.transport._credentials == creds - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_cluster_controller_client_get_transport_class(): @@ -170,14 +169,14 @@ def test_cluster_controller_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -186,14 +185,14 @@ def test_cluster_controller_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -202,90 +201,185 @@ def test_cluster_controller_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
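`GOOGLE_API_USE_MTLS` is renamed to `GOOGLE_API_USE_MTLS_ENDPOINT`, and a separate `GOOGLE_API_USE_CLIENT_CERTIFICATE` variable now gates whether a client certificate is used at all; per these tests, unsupported values raise `MutualTLSChannelError` and `ValueError` respectively. A condensed sketch of the endpoint selection the tests pin down (simplified from the generated client logic):

```python
import os

DEFAULT_ENDPOINT = "dataproc.googleapis.com"
DEFAULT_MTLS_ENDPOINT = "dataproc.mtls.googleapis.com"


def select_endpoint(client_cert_available: bool) -> str:
    use_mtls = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
    if use_mtls not in ("auto", "never", "always"):
        # The real client raises MutualTLSChannelError here.
        raise ValueError("unsupported GOOGLE_API_USE_MTLS_ENDPOINT value")
    use_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
    if use_cert not in ("true", "false"):
        raise ValueError("unsupported GOOGLE_API_USE_CLIENT_CERTIFICATE value")
    if use_mtls == "always" or (
        use_mtls == "auto" and use_cert == "true" and client_cert_available
    ):
        return DEFAULT_MTLS_ENDPOINT
    return DEFAULT_ENDPOINT
```

With both variables unset, `select_endpoint(False)` falls through to the regular endpoint, matching the "auto" default the tests exercise.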
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + ClusterControllerClient, + transports.ClusterControllerGrpcTransport, + "grpc", + "true", + ), + ( + ClusterControllerAsyncClient, + transports.ClusterControllerGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + ClusterControllerClient, + transports.ClusterControllerGrpcTransport, + "grpc", + "false", + ), + ( + ClusterControllerAsyncClient, + transports.ClusterControllerGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ClusterControllerClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ClusterControllerClient), +) +@mock.patch.object( + ClusterControllerAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ClusterControllerAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_cluster_controller_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and default_client_cert_source is provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): - with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds ): patched.return_value = None - client = client_class() + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, + host=expected_host, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=expected_ssl_channel_creds, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", but client_cert_source and default_client_cert_source are None. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None ): - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has - # unsupported value. 
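In the rewritten ADC client-cert cases that follow, `SslCredentials` is simulated by patching its `__init__` and by patching `is_mtls` and `ssl_credentials` as `mock.PropertyMock`s, since properties live on the class rather than the instance. The core of that idiom on a stand-in class:

```python
from unittest import mock


class FakeSslCredentials:
    """Stand-in for google.auth.transport.grpc.SslCredentials."""

    @property
    def is_mtls(self):
        return False


with mock.patch.object(
    FakeSslCredentials, "is_mtls", new_callable=mock.PropertyMock
) as is_mtls_mock:
    is_mtls_mock.return_value = True
    # PropertyMock is installed on the type, so plain attribute access
    # on any instance triggers it.
    assert FakeSslCredentials().is_mtls is True
```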
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id="octopus", - ) + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) @pytest.mark.parametrize( @@ -312,9 +406,9 @@ def test_cluster_controller_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -342,9 +436,9 @@ def test_cluster_controller_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -361,9 +455,9 @@ def test_cluster_controller_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -379,7 +473,7 @@ def test_create_cluster( request = request_type() # Mock the actual call 
within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.create_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") @@ -400,19 +494,19 @@ def test_create_cluster_from_dict(): @pytest.mark.asyncio -async def test_create_cluster_async(transport: str = "grpc_asyncio"): +async def test_create_cluster_async( + transport: str = "grpc_asyncio", request_type=clusters.CreateClusterRequest +): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = clusters.CreateClusterRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") @@ -424,17 +518,22 @@ async def test_create_cluster_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == clusters.CreateClusterRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_create_cluster_async_from_dict(): + await test_create_cluster_async(request_type=dict) + + def test_create_cluster_flattened(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.create_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -479,9 +578,7 @@ async def test_create_cluster_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.create_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -537,7 +634,7 @@ def test_update_cluster( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.update_cluster), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = operations_pb2.Operation(name="operations/spam") @@ -558,19 +655,19 @@ def test_update_cluster_from_dict(): @pytest.mark.asyncio -async def test_update_cluster_async(transport: str = "grpc_asyncio"): +async def test_update_cluster_async( + transport: str = "grpc_asyncio", request_type=clusters.UpdateClusterRequest +): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = clusters.UpdateClusterRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.update_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") @@ -582,17 +679,22 @@ async def test_update_cluster_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == clusters.UpdateClusterRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_update_cluster_async_from_dict(): + await test_update_cluster_async(request_type=dict) + + def test_update_cluster_flattened(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.update_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -645,9 +747,7 @@ async def test_update_cluster_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.update_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -711,7 +811,7 @@ def test_delete_cluster( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") @@ -732,19 +832,19 @@ def test_delete_cluster_from_dict(): @pytest.mark.asyncio -async def test_delete_cluster_async(transport: str = "grpc_asyncio"): +async def test_delete_cluster_async( + transport: str = "grpc_asyncio", request_type=clusters.DeleteClusterRequest +): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = clusters.DeleteClusterRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
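The cluster RPCs are long-running, so the mocked stubs return an `operations_pb2.Operation` and the client surfaces a future (hence the `isinstance(response, future.Future)` assertions). A quick check of the canned proto the tests use:

```python
from google.longrunning import operations_pb2

# The canned LRO response returned by the mocked stubs.
op = operations_pb2.Operation(name="operations/spam")
assert op.name == "operations/spam"
assert not op.done  # a freshly built Operation is not done
```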
- with mock.patch.object( - type(client._client._transport.delete_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") @@ -756,17 +856,22 @@ async def test_delete_cluster_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == clusters.DeleteClusterRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_delete_cluster_async_from_dict(): + await test_delete_cluster_async(request_type=dict) + + def test_delete_cluster_flattened(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -811,9 +916,7 @@ async def test_delete_cluster_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.delete_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -867,7 +970,7 @@ def test_get_cluster(transport: str = "grpc", request_type=clusters.GetClusterRe request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.get_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = clusters.Cluster( project_id="project_id_value", @@ -884,6 +987,7 @@ def test_get_cluster(transport: str = "grpc", request_type=clusters.GetClusterRe assert args[0] == clusters.GetClusterRequest() # Establish that the response is the type that we expect. + assert isinstance(response, clusters.Cluster) assert response.project_id == "project_id_value" @@ -898,19 +1002,19 @@ def test_get_cluster_from_dict(): @pytest.mark.asyncio -async def test_get_cluster_async(transport: str = "grpc_asyncio"): +async def test_get_cluster_async( + transport: str = "grpc_asyncio", request_type=clusters.GetClusterRequest +): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = clusters.GetClusterRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.get_cluster), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( clusters.Cluster( @@ -926,7 +1030,7 @@ async def test_get_cluster_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == clusters.GetClusterRequest() # Establish that the response is the type that we expect. assert isinstance(response, clusters.Cluster) @@ -938,11 +1042,16 @@ async def test_get_cluster_async(transport: str = "grpc_asyncio"): assert response.cluster_uuid == "cluster_uuid_value" +@pytest.mark.asyncio +async def test_get_cluster_async_from_dict(): + await test_get_cluster_async(request_type=dict) + + def test_get_cluster_flattened(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_cluster), "__call__") as call: + with mock.patch.object(type(client.transport.get_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = clusters.Cluster() @@ -987,9 +1096,7 @@ async def test_get_cluster_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.get_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = clusters.Cluster() @@ -1043,7 +1150,7 @@ def test_list_clusters( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_clusters), "__call__") as call: + with mock.patch.object(type(client.transport.list_clusters), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = clusters.ListClustersResponse( next_page_token="next_page_token_value", @@ -1058,6 +1165,7 @@ def test_list_clusters( assert args[0] == clusters.ListClustersRequest() # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListClustersPager) assert response.next_page_token == "next_page_token_value" @@ -1068,19 +1176,19 @@ def test_list_clusters_from_dict(): @pytest.mark.asyncio -async def test_list_clusters_async(transport: str = "grpc_asyncio"): +async def test_list_clusters_async( + transport: str = "grpc_asyncio", request_type=clusters.ListClustersRequest +): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = clusters.ListClustersRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_clusters), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_clusters), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( clusters.ListClustersResponse(next_page_token="next_page_token_value",) @@ -1092,7 +1200,7 @@ async def test_list_clusters_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == clusters.ListClustersRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListClustersAsyncPager) @@ -1100,11 +1208,16 @@ async def test_list_clusters_async(transport: str = "grpc_asyncio"): assert response.next_page_token == "next_page_token_value" +@pytest.mark.asyncio +async def test_list_clusters_async_from_dict(): + await test_list_clusters_async(request_type=dict) + + def test_list_clusters_flattened(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_clusters), "__call__") as call: + with mock.patch.object(type(client.transport.list_clusters), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = clusters.ListClustersResponse() @@ -1147,9 +1260,7 @@ async def test_list_clusters_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_clusters), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_clusters), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = clusters.ListClustersResponse() @@ -1195,7 +1306,7 @@ def test_list_clusters_pager(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_clusters), "__call__") as call: + with mock.patch.object(type(client.transport.list_clusters), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( clusters.ListClustersResponse( @@ -1226,7 +1337,7 @@ def test_list_clusters_pages(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_clusters), "__call__") as call: + with mock.patch.object(type(client.transport.list_clusters), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( clusters.ListClustersResponse( @@ -1243,8 +1354,8 @@ def test_list_clusters_pages(): RuntimeError, ) pages = list(client.list_clusters(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio @@ -1253,9 +1364,7 @@ async def test_list_clusters_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_clusters), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_clusters), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1288,9 +1397,7 @@ async def test_list_clusters_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client._client._transport.list_clusters), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_clusters), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1308,10 +1415,10 @@ async def test_list_clusters_async_pages(): RuntimeError, ) pages = [] - async for page in (await client.list_clusters(request={})).pages: - pages.append(page) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + async for page_ in (await client.list_clusters(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token def test_diagnose_cluster( @@ -1326,9 +1433,7 @@ def test_diagnose_cluster( request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.diagnose_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/spam") @@ -1349,19 +1454,19 @@ def test_diagnose_cluster_from_dict(): @pytest.mark.asyncio -async def test_diagnose_cluster_async(transport: str = "grpc_asyncio"): +async def test_diagnose_cluster_async( + transport: str = "grpc_asyncio", request_type=clusters.DiagnoseClusterRequest +): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = clusters.DiagnoseClusterRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.diagnose_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( operations_pb2.Operation(name="operations/spam") @@ -1373,19 +1478,22 @@ async def test_diagnose_cluster_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == clusters.DiagnoseClusterRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_diagnose_cluster_async_from_dict(): + await test_diagnose_cluster_async(request_type=dict) + + def test_diagnose_cluster_flattened(): client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.diagnose_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -1430,9 +1538,7 @@ async def test_diagnose_cluster_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.diagnose_cluster), "__call__" - ) as call: + with mock.patch.object(type(client.transport.diagnose_cluster), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -1512,7 +1618,7 @@ def test_transport_instance(): credentials=credentials.AnonymousCredentials(), ) client = ClusterControllerClient(transport=transport) - assert client._transport is transport + assert client.transport is transport def test_transport_get_channel(): @@ -1530,10 +1636,25 @@ def test_transport_get_channel(): assert channel +@pytest.mark.parametrize( + "transport_class", + [ + transports.ClusterControllerGrpcTransport, + transports.ClusterControllerGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = ClusterControllerClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.ClusterControllerGrpcTransport,) + assert isinstance(client.transport, transports.ClusterControllerGrpcTransport,) def test_cluster_controller_base_transport_error(): @@ -1594,6 +1715,17 @@ def test_cluster_controller_base_transport_with_credentials_file(): ) +def test_cluster_controller_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.dataproc_v1beta2.services.cluster_controller.transports.ClusterControllerTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ClusterControllerTransport() + adc.assert_called_once() + + def test_cluster_controller_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(auth, "default") as adc: @@ -1626,7 +1758,7 @@ def test_cluster_controller_host_no_port(): api_endpoint="dataproc.googleapis.com" ), ) - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_cluster_controller_host_with_port(): @@ -1636,192 +1768,126 @@ def test_cluster_controller_host_with_port(): api_endpoint="dataproc.googleapis.com:8000" ), ) - assert client._transport._host == "dataproc.googleapis.com:8000" + assert client.transport._host == "dataproc.googleapis.com:8000" def test_cluster_controller_grpc_transport_channel(): channel = grpc.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. 
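
# The new test_transport_adc above pins down the credential fallback: a
# transport built with neither credentials nor credentials_file asks
# google.auth.default(). The same check as a standalone sketch; the patching
# mirrors the test, so no real ADC lookup happens.
from unittest import mock

from google import auth
from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.cluster_controller import transports

with mock.patch.object(auth, "default") as adc:
    adc.return_value = (credentials.AnonymousCredentials(), None)
    transports.ClusterControllerGrpcTransport()  # no credentials passed
    adc.assert_called_once()
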
transport = transports.ClusterControllerGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called + assert transport._ssl_channel_credentials == None def test_cluster_controller_grpc_asyncio_transport_channel(): channel = aio.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.ClusterControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_cluster_controller_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.ClusterControllerGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_cluster_controller_grpc_asyncio_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. 
- mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.ClusterControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.ClusterControllerGrpcTransport, + transports.ClusterControllerGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_cluster_controller_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint +def test_cluster_controller_transport_channel_mtls_with_client_cert_source( + transport_class, ): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - mock_cred = mock.Mock() - transport = transports.ClusterControllerGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == 
mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.ClusterControllerGrpcTransport, + transports.ClusterControllerGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_cluster_controller_grpc_asyncio_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - # Mock google.auth.transport.grpc.SslCredentials class. +def test_cluster_controller_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - mock_cred = mock.Mock() - transport = transports.ClusterControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint=api_mtls_endpoint, - client_cert_source=None, - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel def test_cluster_controller_grpc_lro_client(): client = ClusterControllerClient( credentials=credentials.AnonymousCredentials(), transport="grpc", ) - transport = client._transport + transport = client.transport # Ensure that we have a api-core operations client. assert isinstance(transport.operations_client, operations_v1.OperationsClient,) @@ -1834,10 +1900,132 @@ def test_cluster_controller_grpc_lro_async_client(): client = ClusterControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - transport = client._client._transport + transport = client.transport # Ensure that we have a api-core operations client. assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) # Ensure that subsequent calls to the property send the exact same object. 
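
# The consolidated, parametrized mTLS tests above expect a DeprecationWarning
# for api_mtls_endpoint/client_cert_source; ssl_channel_credentials is the
# forward-looking parameter. A sketch of that spelling, assuming hypothetical
# "key.pem"/"cert.pem" files and the mTLS host used by the endpoint tests.
import grpc

from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.cluster_controller import transports

with open("key.pem", "rb") as key, open("cert.pem", "rb") as cert:
    ssl_creds = grpc.ssl_channel_credentials(
        private_key=key.read(), certificate_chain=cert.read()
    )

transport = transports.ClusterControllerGrpcTransport(
    host="dataproc.mtls.googleapis.com",
    credentials=credentials.AnonymousCredentials(),
    ssl_channel_credentials=ssl_creds,
)
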
assert transport.operations_client is transport.operations_client + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = ClusterControllerClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = ClusterControllerClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ClusterControllerClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = ClusterControllerClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = ClusterControllerClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ClusterControllerClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = ClusterControllerClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = ClusterControllerClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = ClusterControllerClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = ClusterControllerClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = ClusterControllerClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ClusterControllerClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = ClusterControllerClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = ClusterControllerClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
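
# The common resource path helpers added above compose and parse resource names
# shared across clients. A quick round-trip matching the format strings in the
# tests; "my-project"/"us-central1" are arbitrary stand-ins for the animal
# placeholders the generated tests use.
from google.cloud.dataproc_v1beta2.services.cluster_controller import (
    ClusterControllerClient,
)

path = ClusterControllerClient.common_location_path("my-project", "us-central1")
assert path == "projects/my-project/locations/us-central1"
assert ClusterControllerClient.parse_common_location_path(path) == {
    "project": "my-project",
    "location": "us-central1",
}
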
+    actual = ClusterControllerClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_with_DEFAULT_CLIENT_INFO():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(
+        transports.ClusterControllerTransport, "_prep_wrapped_messages"
+    ) as prep:
+        client = ClusterControllerClient(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(
+        transports.ClusterControllerTransport, "_prep_wrapped_messages"
+    ) as prep:
+        transport_class = ClusterControllerClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/dataproc_v1beta2/test_job_controller.py b/tests/unit/gapic/dataproc_v1beta2/test_job_controller.py
index 89e093e6..9c4ebd86 100644
--- a/tests/unit/gapic/dataproc_v1beta2/test_job_controller.py
+++ b/tests/unit/gapic/dataproc_v1beta2/test_job_controller.py
@@ -31,7 +31,7 @@
 from google.api_core import gapic_v1
 from google.api_core import grpc_helpers
 from google.api_core import grpc_helpers_async
-from google.api_core import operation_async
+from google.api_core import operation_async  # type: ignore
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
@@ -42,7 +42,6 @@
 from google.cloud.dataproc_v1beta2.services.job_controller import pagers
 from google.cloud.dataproc_v1beta2.services.job_controller import transports
 from google.cloud.dataproc_v1beta2.types import jobs
-from google.cloud.dataproc_v1beta2.types import jobs as gcd_jobs
 from google.longrunning import operations_pb2
 from google.oauth2 import service_account
 from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
@@ -103,12 +102,12 @@ def test_job_controller_client_from_service_account_file(client_class):
     ) as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
-        assert client._transport._credentials == creds
+        assert client.transport._credentials == creds
 
         client = client_class.from_service_account_json("dummy/file/path.json")
-        assert client._transport._credentials == creds
+        assert client.transport._credentials == creds
 
-        assert client._transport._host == "dataproc.googleapis.com:443"
+        assert client.transport._host == "dataproc.googleapis.com:443"
 
 
 def test_job_controller_client_get_transport_class():
@@ -164,14 +163,14 @@ def test_job_controller_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            api_mtls_endpoint="squid.clam.whelk",
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
 
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
    # "never".
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -180,14 +179,14 @@ def test_job_controller_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}): + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class() @@ -196,90 +195,175 @@ def test_job_controller_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobControllerClient, transports.JobControllerGrpcTransport, "grpc", "true"), + ( + JobControllerAsyncClient, + transports.JobControllerGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobControllerClient, transports.JobControllerGrpcTransport, "grpc", "false"), + ( + JobControllerAsyncClient, + transports.JobControllerGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobControllerClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobControllerClient), +) +@mock.patch.object( + JobControllerAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobControllerAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_job_controller_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): options = client_options.ClientOptions( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=client_cert_source_callback, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", and default_client_cert_source is provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): - with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=True, + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds ): patched.return_value = None - client = client_class() + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, + host=expected_host, scopes=None, - api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=expected_ssl_channel_creds, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) - # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is - # "auto", but client_cert_source and default_client_cert_source are None. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}): + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.mtls.has_default_client_cert_source", - return_value=False, + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None ): - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has - # unsupported value. 
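
# A sketch of the endpoint policy test_job_controller_client_mtls_env_auto
# walks: GOOGLE_API_USE_MTLS_ENDPOINT ("never"/"auto"/"always") selects the
# endpoint, and GOOGLE_API_USE_CLIENT_CERTIFICATE gates client certs under
# "auto". The endpoint constants and the ":443" suffix follow the assertions
# above; running this mutates os.environ, so it is illustrative only.
import os

from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.job_controller import JobControllerClient

os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "always"
client = JobControllerClient(credentials=credentials.AnonymousCredentials())
assert client.transport._host == JobControllerClient.DEFAULT_MTLS_ENDPOINT + ":443"

os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "never"
client = JobControllerClient(credentials=credentials.AnonymousCredentials())
assert client.transport._host == JobControllerClient.DEFAULT_ENDPOINT + ":443"
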
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, - quota_project_id="octopus", - ) + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) @pytest.mark.parametrize( @@ -306,9 +390,9 @@ def test_job_controller_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -336,9 +420,9 @@ def test_job_controller_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - api_mtls_endpoint=client.DEFAULT_ENDPOINT, - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -355,9 +439,9 @@ def test_job_controller_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - api_mtls_endpoint="squid.clam.whelk", - client_cert_source=None, + ssl_channel_credentials=None, quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +455,7 @@ def test_submit_job(transport: str = "grpc", request_type=jobs.SubmitJobRequest) 
request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.submit_job), "__call__") as call: + with mock.patch.object(type(client.transport.submit_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( submitted_by="submitted_by_value", @@ -391,6 +475,7 @@ def test_submit_job(transport: str = "grpc", request_type=jobs.SubmitJobRequest) assert args[0] == jobs.SubmitJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, jobs.Job) assert response.submitted_by == "submitted_by_value" @@ -409,19 +494,19 @@ def test_submit_job_from_dict(): @pytest.mark.asyncio -async def test_submit_job_async(transport: str = "grpc_asyncio"): +async def test_submit_job_async( + transport: str = "grpc_asyncio", request_type=jobs.SubmitJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.SubmitJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.submit_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.submit_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -439,7 +524,7 @@ async def test_submit_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.SubmitJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -455,11 +540,16 @@ async def test_submit_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_submit_job_async_from_dict(): + await test_submit_job_async(request_type=dict) + + def test_submit_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.submit_job), "__call__") as call: + with mock.patch.object(type(client.transport.submit_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -504,9 +594,7 @@ async def test_submit_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.submit_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.submit_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -561,7 +649,7 @@ def test_submit_job_as_operation( # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. 
call.return_value = operations_pb2.Operation(name="operations/spam") @@ -583,18 +671,20 @@ def test_submit_job_as_operation_from_dict(): @pytest.mark.asyncio -async def test_submit_job_as_operation_async(transport: str = "grpc_asyncio"): +async def test_submit_job_as_operation_async( + transport: str = "grpc_asyncio", request_type=jobs.SubmitJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.SubmitJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( @@ -607,18 +697,23 @@ async def test_submit_job_as_operation_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.SubmitJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, future.Future) +@pytest.mark.asyncio +async def test_submit_job_as_operation_async_from_dict(): + await test_submit_job_as_operation_async(request_type=dict) + + def test_submit_job_as_operation_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -665,7 +760,7 @@ async def test_submit_job_as_operation_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.submit_job_as_operation), "__call__" + type(client.transport.submit_job_as_operation), "__call__" ) as call: # Designate an appropriate return value for the call. call.return_value = operations_pb2.Operation(name="operations/op") @@ -720,7 +815,7 @@ def test_get_job(transport: str = "grpc", request_type=jobs.GetJobRequest): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( submitted_by="submitted_by_value", @@ -740,6 +835,7 @@ def test_get_job(transport: str = "grpc", request_type=jobs.GetJobRequest): assert args[0] == jobs.GetJobRequest() # Establish that the response is the type that we expect. 
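
# submit_job_as_operation (and diagnose_cluster earlier) return long-running
# operation futures; .result() would block for the terminal response. A sketch
# mocked the same way as the tests, so no RPC is issued and no polling occurs.
from unittest import mock

from google.api_core import future
from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.job_controller import JobControllerClient
from google.longrunning import operations_pb2

client = JobControllerClient(credentials=credentials.AnonymousCredentials())
with mock.patch.object(
    type(client.transport.submit_job_as_operation), "__call__"
) as call:
    call.return_value = operations_pb2.Operation(name="operations/spam")
    op = client.submit_job_as_operation(request={})
    # op.result() would poll until done; here we only check the future type.
    assert isinstance(op, future.Future)
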
+ assert isinstance(response, jobs.Job) assert response.submitted_by == "submitted_by_value" @@ -758,17 +854,19 @@ def test_get_job_from_dict(): @pytest.mark.asyncio -async def test_get_job_async(transport: str = "grpc_asyncio"): +async def test_get_job_async( + transport: str = "grpc_asyncio", request_type=jobs.GetJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.GetJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -786,7 +884,7 @@ async def test_get_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.GetJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -802,11 +900,16 @@ async def test_get_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_get_job_async_from_dict(): + await test_get_job_async(request_type=dict) + + def test_get_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -847,7 +950,7 @@ async def test_get_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._client._transport.get_job), "__call__") as call: + with mock.patch.object(type(client.transport.get_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -895,7 +998,7 @@ def test_list_jobs(transport: str = "grpc", request_type=jobs.ListJobsRequest): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.ListJobsResponse( next_page_token="next_page_token_value", @@ -910,6 +1013,7 @@ def test_list_jobs(transport: str = "grpc", request_type=jobs.ListJobsRequest): assert args[0] == jobs.ListJobsRequest() # Establish that the response is the type that we expect. 
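
# The new *_from_dict variants re-run each async test with request_type=dict,
# pinning down that a plain mapping is accepted and coerced to the proto
# request. A sketch with placeholder IDs, mocked like the tests so nothing is
# contacted.
from unittest import mock

from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.job_controller import JobControllerClient
from google.cloud.dataproc_v1beta2.types import jobs

client = JobControllerClient(credentials=credentials.AnonymousCredentials())
with mock.patch.object(type(client.transport.get_job), "__call__") as call:
    call.return_value = jobs.Job(submitted_by="someone")
    client.get_job(
        request={"project_id": "my-project", "region": "global", "job_id": "job-1"}
    )
    # The dict was coerced into the proto request before hitting the stub.
    sent = call.call_args[0][0]
    assert isinstance(sent, jobs.GetJobRequest)
    assert sent.job_id == "job-1"
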
+ assert isinstance(response, pagers.ListJobsPager) assert response.next_page_token == "next_page_token_value" @@ -920,19 +1024,19 @@ def test_list_jobs_from_dict(): @pytest.mark.asyncio -async def test_list_jobs_async(transport: str = "grpc_asyncio"): +async def test_list_jobs_async( + transport: str = "grpc_asyncio", request_type=jobs.ListJobsRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.ListJobsRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_jobs), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.ListJobsResponse(next_page_token="next_page_token_value",) @@ -944,7 +1048,7 @@ async def test_list_jobs_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.ListJobsRequest() # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListJobsAsyncPager) @@ -952,11 +1056,16 @@ async def test_list_jobs_async(transport: str = "grpc_asyncio"): assert response.next_page_token == "next_page_token_value" +@pytest.mark.asyncio +async def test_list_jobs_async_from_dict(): + await test_list_jobs_async(request_type=dict) + + def test_list_jobs_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.ListJobsResponse() @@ -997,9 +1106,7 @@ async def test_list_jobs_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_jobs), "__call__" - ) as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.ListJobsResponse() @@ -1043,7 +1150,7 @@ def test_list_jobs_pager(): client = JobControllerClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( jobs.ListJobsResponse( @@ -1069,7 +1176,7 @@ def test_list_jobs_pages(): client = JobControllerClient(credentials=credentials.AnonymousCredentials,) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_jobs), "__call__") as call: + with mock.patch.object(type(client.transport.list_jobs), "__call__") as call: # Set the response to a series of pages. 
call.side_effect = ( jobs.ListJobsResponse( @@ -1081,8 +1188,8 @@ def test_list_jobs_pages(): RuntimeError, ) pages = list(client.list_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token @pytest.mark.asyncio @@ -1091,9 +1198,7 @@ async def test_list_jobs_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_jobs), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_jobs), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1121,9 +1226,7 @@ async def test_list_jobs_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client._client._transport.list_jobs), - "__call__", - new_callable=mock.AsyncMock, + type(client.transport.list_jobs), "__call__", new_callable=mock.AsyncMock ) as call: # Set the response to a series of pages. call.side_effect = ( @@ -1136,10 +1239,10 @@ async def test_list_jobs_async_pages(): RuntimeError, ) pages = [] - async for page in (await client.list_jobs(request={})).pages: - pages.append(page) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token + async for page_ in (await client.list_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token def test_update_job(transport: str = "grpc", request_type=jobs.UpdateJobRequest): @@ -1152,7 +1255,7 @@ def test_update_job(transport: str = "grpc", request_type=jobs.UpdateJobRequest) request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_job), "__call__") as call: + with mock.patch.object(type(client.transport.update_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( submitted_by="submitted_by_value", @@ -1172,6 +1275,7 @@ def test_update_job(transport: str = "grpc", request_type=jobs.UpdateJobRequest) assert args[0] == jobs.UpdateJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, jobs.Job) assert response.submitted_by == "submitted_by_value" @@ -1190,19 +1294,19 @@ def test_update_job_from_dict(): @pytest.mark.asyncio -async def test_update_job_async(transport: str = "grpc_asyncio"): +async def test_update_job_async( + transport: str = "grpc_asyncio", request_type=jobs.UpdateJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.UpdateJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.update_job), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -1220,7 +1324,7 @@ async def test_update_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.UpdateJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -1236,6 +1340,11 @@ async def test_update_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_update_job_async_from_dict(): + await test_update_job_async(request_type=dict) + + def test_cancel_job(transport: str = "grpc", request_type=jobs.CancelJobRequest): client = JobControllerClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1246,7 +1355,7 @@ def test_cancel_job(transport: str = "grpc", request_type=jobs.CancelJobRequest) request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.cancel_job), "__call__") as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job( submitted_by="submitted_by_value", @@ -1266,6 +1375,7 @@ def test_cancel_job(transport: str = "grpc", request_type=jobs.CancelJobRequest) assert args[0] == jobs.CancelJobRequest() # Establish that the response is the type that we expect. + assert isinstance(response, jobs.Job) assert response.submitted_by == "submitted_by_value" @@ -1284,19 +1394,19 @@ def test_cancel_job_from_dict(): @pytest.mark.asyncio -async def test_cancel_job_async(transport: str = "grpc_asyncio"): +async def test_cancel_job_async( + transport: str = "grpc_asyncio", request_type=jobs.CancelJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.CancelJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( jobs.Job( @@ -1314,7 +1424,7 @@ async def test_cancel_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.CancelJobRequest() # Establish that the response is the type that we expect. assert isinstance(response, jobs.Job) @@ -1330,11 +1440,16 @@ async def test_cancel_job_async(transport: str = "grpc_asyncio"): assert response.done is True +@pytest.mark.asyncio +async def test_cancel_job_async_from_dict(): + await test_cancel_job_async(request_type=dict) + + def test_cancel_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.cancel_job), "__call__") as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = jobs.Job() @@ -1375,9 +1490,7 @@ async def test_cancel_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.cancel_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = jobs.Job() @@ -1425,7 +1538,7 @@ def test_delete_job(transport: str = "grpc", request_type=jobs.DeleteJobRequest) request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_job), "__call__") as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1446,19 +1559,19 @@ def test_delete_job_from_dict(): @pytest.mark.asyncio -async def test_delete_job_async(transport: str = "grpc_asyncio"): +async def test_delete_job_async( + transport: str = "grpc_asyncio", request_type=jobs.DeleteJobRequest +): client = JobControllerAsyncClient( credentials=credentials.AnonymousCredentials(), transport=transport, ) # Everything is optional in proto3 as far as the runtime is concerned, # and we are mocking out the actual API, so just send an empty request. - request = jobs.DeleteJobRequest() + request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) @@ -1468,17 +1581,22 @@ async def test_delete_job_async(transport: str = "grpc_asyncio"): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] - assert args[0] == request + assert args[0] == jobs.DeleteJobRequest() # Establish that the response is the type that we expect. assert response is None +@pytest.mark.asyncio +async def test_delete_job_async_from_dict(): + await test_delete_job_async(request_type=dict) + + def test_delete_job_flattened(): client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_job), "__call__") as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1519,9 +1637,7 @@ async def test_delete_job_flattened_async(): client = JobControllerAsyncClient(credentials=credentials.AnonymousCredentials(),) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_job), "__call__" - ) as call: + with mock.patch.object(type(client.transport.delete_job), "__call__") as call: # Designate an appropriate return value for the call. 
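
# The flattened tests around here cover the keyword convenience form. A sketch
# of both calling conventions (mocked, placeholder IDs); per the GAPIC
# convention these tests follow, combining flattened arguments with a
# non-empty explicit request raises ValueError.
from unittest import mock

from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.job_controller import JobControllerClient

client = JobControllerClient(credentials=credentials.AnonymousCredentials())
with mock.patch.object(type(client.transport.delete_job), "__call__") as call:
    call.return_value = None
    # Flattened form: the request proto is assembled from the keywords.
    client.delete_job(project_id="my-project", region="global", job_id="job-1")
    # Request form, equivalent:
    client.delete_job(
        request={"project_id": "my-project", "region": "global", "job_id": "job-1"}
    )
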
call.return_value = None @@ -1595,7 +1711,7 @@ def test_transport_instance(): credentials=credentials.AnonymousCredentials(), ) client = JobControllerClient(transport=transport) - assert client._transport is transport + assert client.transport is transport def test_transport_get_channel(): @@ -1613,10 +1729,25 @@ def test_transport_get_channel(): assert channel +@pytest.mark.parametrize( + "transport_class", + [ + transports.JobControllerGrpcTransport, + transports.JobControllerGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = JobControllerClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.JobControllerGrpcTransport,) + assert isinstance(client.transport, transports.JobControllerGrpcTransport,) def test_job_controller_base_transport_error(): @@ -1678,6 +1809,17 @@ def test_job_controller_base_transport_with_credentials_file(): ) +def test_job_controller_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.dataproc_v1beta2.services.job_controller.transports.JobControllerTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobControllerTransport() + adc.assert_called_once() + + def test_job_controller_auth_adc(): # If no credentials are provided, we should use ADC credentials. with mock.patch.object(auth, "default") as adc: @@ -1710,7 +1852,7 @@ def test_job_controller_host_no_port(): api_endpoint="dataproc.googleapis.com" ), ) - assert client._transport._host == "dataproc.googleapis.com:443" + assert client.transport._host == "dataproc.googleapis.com:443" def test_job_controller_host_with_port(): @@ -1720,192 +1862,124 @@ def test_job_controller_host_with_port(): api_endpoint="dataproc.googleapis.com:8000" ), ) - assert client._transport._host == "dataproc.googleapis.com:8000" + assert client.transport._host == "dataproc.googleapis.com:8000" def test_job_controller_grpc_transport_channel(): channel = grpc.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. transport = transports.JobControllerGrpcTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called + assert transport._ssl_channel_credentials == None def test_job_controller_grpc_asyncio_transport_channel(): channel = aio.insecure_channel("http://localhost/") - # Check that if channel is provided, mtls endpoint and client_cert_source - # won't be used. - callback = mock.MagicMock() + # Check that channel is used if provided. 
transport = transports.JobControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - channel=channel, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=callback, + host="squid.clam.whelk", channel=channel, ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" - assert not callback.called - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_job_controller_grpc_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.JobControllerGrpcTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@mock.patch("grpc.ssl_channel_credentials", autospec=True) -@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True) -def test_job_controller_grpc_asyncio_transport_channel_mtls_with_client_cert_source( - grpc_create_channel, grpc_ssl_channel_cred -): - # Check that if channel is None, but api_mtls_endpoint and client_cert_source - # are provided, then a mTLS channel will be created. - mock_cred = mock.Mock() - - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - transport = transports.JobControllerGrpcAsyncIOTransport( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize( - "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"] + "transport_class", + [ + transports.JobControllerGrpcTransport, + transports.JobControllerGrpcAsyncIOTransport, + ], ) -@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True) -def test_job_controller_grpc_transport_channel_mtls_with_adc( - grpc_create_channel, api_mtls_endpoint -): - # Check that if channel and client_cert_source are None, but api_mtls_endpoint - # is provided, then a mTLS channel will be created with SSL ADC. 
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    # Mock google.auth.transport.grpc.SslCredentials class.
-    mock_ssl_cred = mock.Mock()
-    with mock.patch.multiple(
-        "google.auth.transport.grpc.SslCredentials",
-        __init__=mock.Mock(return_value=None),
-        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
-    ):
-        mock_cred = mock.Mock()
-        transport = transports.JobControllerGrpcTransport(
-            host="squid.clam.whelk",
-            credentials=mock_cred,
-            api_mtls_endpoint=api_mtls_endpoint,
-            client_cert_source=None,
-        )
-        grpc_create_channel.assert_called_once_with(
-            "mtls.squid.clam.whelk:443",
-            credentials=mock_cred,
-            credentials_file=None,
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            ssl_credentials=mock_ssl_cred,
-            quota_project_id=None,
-        )
-        assert transport.grpc_channel == mock_grpc_channel
+def test_job_controller_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
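The replacement test folds both transport classes into one parametrized case and expects a `DeprecationWarning`: passing `api_mtls_endpoint`/`client_cert_source` directly to the transport still works but is now deprecated. A rough sketch of the pattern the test exercises (the certificate callback and its byte strings are illustrative stand-ins, mirroring the fixtures these tests use):

```python
import pytest
from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.job_controller import transports


def client_cert_source_callback():
    # Illustrative stand-in for a real client certificate source.
    return b"cert bytes", b"key bytes"


# The deprecated mTLS keyword arguments now trigger a warning.
with pytest.warns(DeprecationWarning):
    transport = transports.JobControllerGrpcTransport(
        host="squid.clam.whelk",
        credentials=credentials.AnonymousCredentials(),
        api_mtls_endpoint="mtls.squid.clam.whelk",
        client_cert_source=client_cert_source_callback,
    )
```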


 @pytest.mark.parametrize(
-    "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"]
+    "transport_class",
+    [
+        transports.JobControllerGrpcTransport,
+        transports.JobControllerGrpcAsyncIOTransport,
+    ],
 )
-@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True)
-def test_job_controller_grpc_asyncio_transport_channel_mtls_with_adc(
-    grpc_create_channel, api_mtls_endpoint
-):
-    # Check that if channel and client_cert_source are None, but api_mtls_endpoint
-    # is provided, then a mTLS channel will be created with SSL ADC.
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    # Mock google.auth.transport.grpc.SslCredentials class.
+def test_job_controller_transport_channel_mtls_with_adc(transport_class):
     mock_ssl_cred = mock.Mock()
     with mock.patch.multiple(
         "google.auth.transport.grpc.SslCredentials",
         __init__=mock.Mock(return_value=None),
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
-        mock_cred = mock.Mock()
-        transport = transports.JobControllerGrpcAsyncIOTransport(
-            host="squid.clam.whelk",
-            credentials=mock_cred,
-            api_mtls_endpoint=api_mtls_endpoint,
-            client_cert_source=None,
-        )
-        grpc_create_channel.assert_called_once_with(
-            "mtls.squid.clam.whelk:443",
-            credentials=mock_cred,
-            credentials_file=None,
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            ssl_credentials=mock_ssl_cred,
-            quota_project_id=None,
-        )
-        assert transport.grpc_channel == mock_grpc_channel
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel


 def test_job_controller_grpc_lro_client():
     client = JobControllerClient(
         credentials=credentials.AnonymousCredentials(), transport="grpc",
     )
-    transport = client._transport
+    transport = client.transport

     # Ensure that we have a api-core operations client.
     assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
@@ -1918,10 +1992,132 @@ def test_job_controller_grpc_lro_async_client():
     client = JobControllerAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
     )
-    transport = client._client._transport
+    transport = client.transport

     # Ensure that we have a api-core operations client.
     assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)

     # Ensure that subsequent calls to the property send the exact same object.
     assert transport.operations_client is transport.operations_client
+
+
+def test_common_billing_account_path():
+    billing_account = "squid"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = JobControllerClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "clam",
+    }
+    path = JobControllerClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobControllerClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "whelk"
+
+    expected = "folders/{folder}".format(folder=folder,)
+    actual = JobControllerClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "octopus",
+    }
+    path = JobControllerClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobControllerClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "oyster"
+
+    expected = "organizations/{organization}".format(organization=organization,)
+    actual = JobControllerClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "nudibranch",
+    }
+    path = JobControllerClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobControllerClient.parse_common_organization_path(path)
+    assert expected == actual
+
+
+def test_common_project_path():
+    project = "cuttlefish"
+
+    expected = "projects/{project}".format(project=project,)
+    actual = JobControllerClient.common_project_path(project)
+    assert expected == actual
+
+
+def test_parse_common_project_path():
+    expected = {
+        "project": "mussel",
+    }
+    path = JobControllerClient.common_project_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobControllerClient.parse_common_project_path(path)
+    assert expected == actual
+
+
+def test_common_location_path():
+    project = "winkle"
+    location = "nautilus"
+
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
+    actual = JobControllerClient.common_location_path(project, location)
+    assert expected == actual
+
+
+def test_parse_common_location_path():
+    expected = {
+        "project": "scallop",
+        "location": "abalone",
+    }
+    path = JobControllerClient.common_location_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobControllerClient.parse_common_location_path(path)
+    assert expected == actual
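These `common_*_path` helpers are the "common resource paths" in this PR's title: every generated client now carries class methods that build and parse the resource name formats shared across Google APIs (billing account, folder, organization, project, location). Usage is symmetric, as the tests' reversibility checks show:

```python
from google.cloud.dataproc_v1beta2.services.job_controller import JobControllerClient

# Build a canonical resource name from its components...
path = JobControllerClient.common_location_path("my-project", "us-central1")
assert path == "projects/my-project/locations/us-central1"

# ...and recover the components from a resource name.
parsed = JobControllerClient.parse_common_location_path(path)
assert parsed == {"project": "my-project", "location": "us-central1"}
```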
+
+
+def test_client_withDEFAULT_CLIENT_INFO():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(
+        transports.JobControllerTransport, "_prep_wrapped_messages"
+    ) as prep:
+        client = JobControllerClient(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(
+        transports.JobControllerTransport, "_prep_wrapped_messages"
+    ) as prep:
+        transport_class = JobControllerClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
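`test_client_withDEFAULT_CLIENT_INFO` pins down the new `client_info` plumbing: whatever `ClientInfo` the caller supplies is forwarded to the transport's `_prep_wrapped_messages`, which stamps user-agent metadata onto every wrapped RPC. A sketch of the caller-facing side (the version string is an illustrative placeholder):

```python
from google.api_core import gapic_v1
from google.auth import credentials
from google.cloud.dataproc_v1beta2.services.job_controller import JobControllerClient

# Custom client info flows through to the transport's wrapped methods.
info = gapic_v1.client_info.ClientInfo(client_library_version="my-app/1.2.3")
client = JobControllerClient(
    credentials=credentials.AnonymousCredentials(), client_info=info,
)
```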
diff --git a/tests/unit/gapic/dataproc_v1beta2/test_workflow_template_service.py b/tests/unit/gapic/dataproc_v1beta2/test_workflow_template_service.py
index 07ea3560..ef40c09d 100644
--- a/tests/unit/gapic/dataproc_v1beta2/test_workflow_template_service.py
+++ b/tests/unit/gapic/dataproc_v1beta2/test_workflow_template_service.py
@@ -31,7 +31,7 @@
 from google.api_core import gapic_v1
 from google.api_core import grpc_helpers
 from google.api_core import grpc_helpers_async
-from google.api_core import operation_async
+from google.api_core import operation_async  # type: ignore
 from google.api_core import operations_v1
 from google.auth import credentials
 from google.auth.exceptions import MutualTLSChannelError
@@ -44,9 +44,7 @@
 from google.cloud.dataproc_v1beta2.services.workflow_template_service import pagers
 from google.cloud.dataproc_v1beta2.services.workflow_template_service import transports
 from google.cloud.dataproc_v1beta2.types import clusters
-from google.cloud.dataproc_v1beta2.types import clusters as gcd_clusters
 from google.cloud.dataproc_v1beta2.types import jobs
-from google.cloud.dataproc_v1beta2.types import jobs as gcd_jobs
 from google.cloud.dataproc_v1beta2.types import shared
 from google.cloud.dataproc_v1beta2.types import workflow_templates
 from google.longrunning import operations_pb2
@@ -110,12 +108,12 @@ def test_workflow_template_service_client_from_service_account_file(client_class
     ) as factory:
         factory.return_value = creds
         client = client_class.from_service_account_file("dummy/file/path.json")
-        assert client._transport._credentials == creds
+        assert client.transport._credentials == creds

         client = client_class.from_service_account_json("dummy/file/path.json")
-        assert client._transport._credentials == creds
+        assert client.transport._credentials == creds

-        assert client._transport._host == "dataproc.googleapis.com:443"
+        assert client.transport._host == "dataproc.googleapis.com:443"


 def test_workflow_template_service_client_get_transport_class():
@@ -175,14 +173,14 @@ def test_workflow_template_service_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            api_mtls_endpoint="squid.clam.whelk",
-            client_cert_source=None,
+            ssl_channel_credentials=None,
             quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
         )

-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "never".
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "never"}):
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
         with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
@@ -191,14 +189,14 @@ def test_workflow_template_service_client_client_options(
                 credentials_file=None,
                 host=client.DEFAULT_ENDPOINT,
                 scopes=None,
-                api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-                client_cert_source=None,
+                ssl_channel_credentials=None,
                 quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
             )

-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
     # "always".
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "always"}):
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
         with mock.patch.object(transport_class, "__init__") as patched:
             patched.return_value = None
             client = client_class()
@@ -207,90 +205,185 @@ def test_workflow_template_service_client_client_options(
                 credentials_file=None,
                 host=client.DEFAULT_MTLS_ENDPOINT,
                 scopes=None,
-                api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-                client_cert_source=None,
+                ssl_channel_credentials=None,
                 quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
             )

-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", and client_cert_source is provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+    # unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+        with pytest.raises(MutualTLSChannelError):
+            client = client_class()
+
+    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
+    ):
+        with pytest.raises(ValueError):
+            client = client_class()
+
+    # Check the case quota_project_id is provided
+    options = client_options.ClientOptions(quota_project_id="octopus")
+    with mock.patch.object(transport_class, "__init__") as patched:
+        patched.return_value = None
+        client = client_class(client_options=options)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host=client.DEFAULT_ENDPOINT,
+            scopes=None,
+            ssl_channel_credentials=None,
+            quota_project_id="octopus",
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+        )
+
+
+@pytest.mark.parametrize(
+    "client_class,transport_class,transport_name,use_client_cert_env",
+    [
+        (
+            WorkflowTemplateServiceClient,
+            transports.WorkflowTemplateServiceGrpcTransport,
+            "grpc",
+            "true",
+        ),
+        (
+            WorkflowTemplateServiceAsyncClient,
+            transports.WorkflowTemplateServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "true",
+        ),
+        (
+            WorkflowTemplateServiceClient,
+            transports.WorkflowTemplateServiceGrpcTransport,
+            "grpc",
+            "false",
+        ),
+        (
+            WorkflowTemplateServiceAsyncClient,
+            transports.WorkflowTemplateServiceGrpcAsyncIOTransport,
+            "grpc_asyncio",
+            "false",
+        ),
+    ],
+)
+@mock.patch.object(
+    WorkflowTemplateServiceClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(WorkflowTemplateServiceClient),
+)
+@mock.patch.object(
+    WorkflowTemplateServiceAsyncClient,
+    "DEFAULT_ENDPOINT",
+    modify_default_endpoint(WorkflowTemplateServiceAsyncClient),
+)
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_workflow_template_service_client_mtls_env_auto(
+    client_class, transport_class, transport_name, use_client_cert_env
+):
+    # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+    # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+    # Check the case client_cert_source is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
         options = client_options.ClientOptions(
             client_cert_source=client_cert_source_callback
         )
         with mock.patch.object(transport_class, "__init__") as patched:
-            patched.return_value = None
-            client = client_class(client_options=options)
-            patched.assert_called_once_with(
-                credentials=None,
-                credentials_file=None,
-                host=client.DEFAULT_MTLS_ENDPOINT,
-                scopes=None,
-                api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-                client_cert_source=client_cert_source_callback,
-                quota_project_id=None,
-            )
-
-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", and default_client_cert_source is provided.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
+            ssl_channel_creds = mock.Mock()
             with mock.patch(
-                "google.auth.transport.mtls.has_default_client_cert_source",
-                return_value=True,
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
             ):
                 patched.return_value = None
-                client = client_class()
+                client = client_class(client_options=options)
+
+                if use_client_cert_env == "false":
+                    expected_ssl_channel_creds = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_ssl_channel_creds = ssl_channel_creds
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT
+
                 patched.assert_called_once_with(
                     credentials=None,
                     credentials_file=None,
-                    host=client.DEFAULT_MTLS_ENDPOINT,
+                    host=expected_host,
                     scopes=None,
-                    api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
-                    client_cert_source=None,
+                    ssl_channel_credentials=expected_ssl_channel_creds,
                     quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
                 )

-    # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
-    # "auto", but client_cert_source and default_client_cert_source are None.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "auto"}):
+    # Check the case ADC client cert is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.mtls.has_default_client_cert_source",
-                return_value=False,
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
             ):
-                patched.return_value = None
-                client = client_class()
-                patched.assert_called_once_with(
-                    credentials=None,
-                    credentials_file=None,
-                    host=client.DEFAULT_ENDPOINT,
-                    scopes=None,
-                    api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-                    client_cert_source=None,
-                    quota_project_id=None,
-                )
-
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
-    # unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS": "Unsupported"}):
-        with pytest.raises(MutualTLSChannelError):
-            client = client_class()
-
-    # Check the case quota_project_id is provided
-    options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host=client.DEFAULT_ENDPOINT,
-            scopes=None,
-            api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-            client_cert_source=None,
-            quota_project_id="octopus",
-        )
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
+                        if use_client_cert_env == "false":
+                            is_mtls_mock.return_value = False
+                            ssl_credentials_mock.return_value = None
+                            expected_host = client.DEFAULT_ENDPOINT
+                            expected_ssl_channel_creds = None
+                        else:
+                            is_mtls_mock.return_value = True
+                            ssl_credentials_mock.return_value = mock.Mock()
+                            expected_host = client.DEFAULT_MTLS_ENDPOINT
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )
+
+                        patched.return_value = None
+                        client = client_class()
+                        patched.assert_called_once_with(
+                            credentials=None,
+                            credentials_file=None,
+                            host=expected_host,
+                            scopes=None,
+                            ssl_channel_credentials=expected_ssl_channel_creds,
+                            quota_project_id=None,
+                            client_info=transports.base.DEFAULT_CLIENT_INFO,
+                        )
+
+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    is_mtls_mock.return_value = False
+                    patched.return_value = None
+                    client = client_class()
+                    patched.assert_called_once_with(
+                        credentials=None,
+                        credentials_file=None,
+                        host=client.DEFAULT_ENDPOINT,
+                        scopes=None,
+                        ssl_channel_credentials=None,
+                        quota_project_id=None,
+                        client_info=transports.base.DEFAULT_CLIENT_INFO,
+                    )
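The parametrized test above encodes the new endpoint decision matrix: `GOOGLE_API_USE_CLIENT_CERTIFICATE` ("true"/"false") controls whether a client certificate is used at all, and `GOOGLE_API_USE_MTLS_ENDPOINT` ("never"/"auto"/"always") controls which endpoint is chosen, replacing the single `GOOGLE_API_USE_MTLS` variable. Roughly, the selection the tests expect looks like this (a sketch of the logic, not the library's actual code):

```python
import os


def select_endpoint(default_endpoint, mtls_endpoint, cert_available):
    # Sketch of the endpoint decision these tests encode.
    use_mtls_endpoint = os.environ.get("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
    use_cert = os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
    if use_mtls_endpoint == "always":
        return mtls_endpoint
    if use_mtls_endpoint == "auto" and use_cert and cert_available:
        return mtls_endpoint
    # "never", or no usable client certificate.
    return default_endpoint
```

Any other value of either variable raises (`MutualTLSChannelError` and `ValueError` respectively), as the "Unsupported" cases check.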
@@ -321,9 +414,9 @@ def test_workflow_template_service_client_client_options_scopes(
         credentials_file=None,
         host=client.DEFAULT_ENDPOINT,
         scopes=["1", "2"],
-        api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-        client_cert_source=None,
+        ssl_channel_credentials=None,
         quota_project_id=None,
+        client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -355,9 +448,9 @@ def test_workflow_template_service_client_client_options_credentials_file(
         credentials_file="credentials.json",
         host=client.DEFAULT_ENDPOINT,
         scopes=None,
-        api_mtls_endpoint=client.DEFAULT_ENDPOINT,
-        client_cert_source=None,
+        ssl_channel_credentials=None,
         quota_project_id=None,
+        client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -374,9 +467,9 @@ def test_workflow_template_service_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        api_mtls_endpoint="squid.clam.whelk",
-        client_cert_source=None,
+        ssl_channel_credentials=None,
         quota_project_id=None,
+        client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -394,7 +487,7 @@ def test_create_workflow_template(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.create_workflow_template), "__call__"
+        type(client.transport.create_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate(
@@ -410,6 +503,7 @@ def test_create_workflow_template(
         assert args[0] == workflow_templates.CreateWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, workflow_templates.WorkflowTemplate)

     assert response.id == "id_value"
@@ -424,18 +518,21 @@ def test_create_workflow_template_from_dict():

 @pytest.mark.asyncio
-async def test_create_workflow_template_async(transport: str = "grpc_asyncio"):
+async def test_create_workflow_template_async(
+    transport: str = "grpc_asyncio",
+    request_type=workflow_templates.CreateWorkflowTemplateRequest,
+):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.CreateWorkflowTemplateRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.create_workflow_template), "__call__"
+        type(client.transport.create_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
@@ -450,7 +547,7 @@ async def test_create_workflow_template_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.CreateWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, workflow_templates.WorkflowTemplate)
@@ -462,6 +559,11 @@ async def test_create_workflow_template_async(transport: str = "grpc_asyncio"):
     assert response.version == 774


+@pytest.mark.asyncio
+async def test_create_workflow_template_async_from_dict():
+    await test_create_workflow_template_async(request_type=dict)
+
+
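The new `*_async_from_dict` variants re-run each async test with `request_type=dict`, covering the proto-plus behavior that lets callers pass a plain dict wherever a request message is expected. In application code that looks roughly like this (resource names and template body are illustrative placeholders, and the call assumes working credentials):

```python
from google.cloud.dataproc_v1beta2.services.workflow_template_service import (
    WorkflowTemplateServiceClient,
)

client = WorkflowTemplateServiceClient()  # assumes ADC is configured

# The dict is coerced into a CreateWorkflowTemplateRequest by proto-plus.
response = client.create_workflow_template(
    request={
        "parent": "projects/my-project/regions/us-central1",
        "template": {"id": "my-template"},
    }
)
```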
 def test_create_workflow_template_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -474,7 +576,7 @@ def test_create_workflow_template_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.create_workflow_template), "__call__"
+        type(client.transport.create_workflow_template), "__call__"
     ) as call:
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -503,7 +605,7 @@ async def test_create_workflow_template_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.create_workflow_template), "__call__"
+        type(client.transport.create_workflow_template), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             workflow_templates.WorkflowTemplate()
@@ -528,7 +630,7 @@ def test_create_workflow_template_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.create_workflow_template), "__call__"
+        type(client.transport.create_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -573,7 +675,7 @@ async def test_create_workflow_template_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.create_workflow_template), "__call__"
+        type(client.transport.create_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -627,7 +729,7 @@ def test_get_workflow_template(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.get_workflow_template), "__call__"
+        type(client.transport.get_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate(
@@ -643,6 +745,7 @@ def test_get_workflow_template(
         assert args[0] == workflow_templates.GetWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, workflow_templates.WorkflowTemplate)

     assert response.id == "id_value"
@@ -657,18 +760,21 @@ def test_get_workflow_template_from_dict():

 @pytest.mark.asyncio
-async def test_get_workflow_template_async(transport: str = "grpc_asyncio"):
+async def test_get_workflow_template_async(
+    transport: str = "grpc_asyncio",
+    request_type=workflow_templates.GetWorkflowTemplateRequest,
+):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.GetWorkflowTemplateRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.get_workflow_template), "__call__"
+        type(client.transport.get_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
@@ -683,7 +789,7 @@ async def test_get_workflow_template_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.GetWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, workflow_templates.WorkflowTemplate)
@@ -695,6 +801,11 @@ async def test_get_workflow_template_async(transport: str = "grpc_asyncio"):
     assert response.version == 774


+@pytest.mark.asyncio
+async def test_get_workflow_template_async_from_dict():
+    await test_get_workflow_template_async(request_type=dict)
+
+
 def test_get_workflow_template_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -707,7 +818,7 @@ def test_get_workflow_template_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.get_workflow_template), "__call__"
+        type(client.transport.get_workflow_template), "__call__"
     ) as call:
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -736,7 +847,7 @@ async def test_get_workflow_template_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.get_workflow_template), "__call__"
+        type(client.transport.get_workflow_template), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             workflow_templates.WorkflowTemplate()
@@ -761,7 +872,7 @@ def test_get_workflow_template_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.get_workflow_template), "__call__"
+        type(client.transport.get_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -799,7 +910,7 @@ async def test_get_workflow_template_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.get_workflow_template), "__call__"
+        type(client.transport.get_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -847,7 +958,7 @@ def test_instantiate_workflow_template(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.instantiate_workflow_template), "__call__"
+        type(client.transport.instantiate_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/spam")
@@ -869,18 +980,21 @@ def test_instantiate_workflow_template_from_dict():

 @pytest.mark.asyncio
-async def test_instantiate_workflow_template_async(transport: str = "grpc_asyncio"):
+async def test_instantiate_workflow_template_async(
+    transport: str = "grpc_asyncio",
+    request_type=workflow_templates.InstantiateWorkflowTemplateRequest,
+):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.InstantiateWorkflowTemplateRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.instantiate_workflow_template), "__call__"
+        type(client.transport.instantiate_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
@@ -893,12 +1007,17 @@ async def test_instantiate_workflow_template_async(transport: str = "grpc_asynci
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.InstantiateWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_instantiate_workflow_template_async_from_dict():
+    await test_instantiate_workflow_template_async(request_type=dict)
+
+
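`instantiate_workflow_template` is a long-running operation: the stub returns an `operations_pb2.Operation` and the client wraps it in a future, hence the `isinstance(response, future.Future)` assertions. Typical caller-side usage, sketched under the assumption of a real project and template:

```python
from google.cloud.dataproc_v1beta2.services.workflow_template_service import (
    WorkflowTemplateServiceClient,
)

client = WorkflowTemplateServiceClient()  # assumes ADC is configured

operation = client.instantiate_workflow_template(
    name="projects/my-project/regions/us-central1/workflowTemplates/my-template"
)
# Block until the workflow finishes; raises if the operation failed.
operation.result()
```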
 def test_instantiate_workflow_template_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -911,7 +1030,7 @@ def test_instantiate_workflow_template_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.instantiate_workflow_template), "__call__"
+        type(client.transport.instantiate_workflow_template), "__call__"
     ) as call:
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -940,7 +1059,7 @@ async def test_instantiate_workflow_template_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.instantiate_workflow_template), "__call__"
+        type(client.transport.instantiate_workflow_template), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             operations_pb2.Operation(name="operations/op")
@@ -965,7 +1084,7 @@ def test_instantiate_workflow_template_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.instantiate_workflow_template), "__call__"
+        type(client.transport.instantiate_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1009,7 +1128,7 @@ async def test_instantiate_workflow_template_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.instantiate_workflow_template), "__call__"
+        type(client.transport.instantiate_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1063,7 +1182,7 @@ def test_instantiate_inline_workflow_template(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.instantiate_inline_workflow_template), "__call__"
+        type(client.transport.instantiate_inline_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/spam")
@@ -1087,6 +1206,7 @@ def test_instantiate_inline_workflow_template_from_dict():
 @pytest.mark.asyncio
 async def test_instantiate_inline_workflow_template_async(
     transport: str = "grpc_asyncio",
+    request_type=workflow_templates.InstantiateInlineWorkflowTemplateRequest,
 ):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
@@ -1094,11 +1214,11 @@ async def test_instantiate_inline_workflow_template_async(

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.InstantiateInlineWorkflowTemplateRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.instantiate_inline_workflow_template), "__call__"
+        type(client.transport.instantiate_inline_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
@@ -1111,12 +1231,17 @@ async def test_instantiate_inline_workflow_template_async(
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.InstantiateInlineWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, future.Future)


+@pytest.mark.asyncio
+async def test_instantiate_inline_workflow_template_async_from_dict():
+    await test_instantiate_inline_workflow_template_async(request_type=dict)
+
+
 def test_instantiate_inline_workflow_template_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1129,7 +1254,7 @@ def test_instantiate_inline_workflow_template_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.instantiate_inline_workflow_template), "__call__"
+        type(client.transport.instantiate_inline_workflow_template), "__call__"
     ) as call:
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1158,7 +1283,7 @@ async def test_instantiate_inline_workflow_template_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.instantiate_inline_workflow_template), "__call__"
+        type(client.transport.instantiate_inline_workflow_template), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             operations_pb2.Operation(name="operations/op")
@@ -1183,7 +1308,7 @@ def test_instantiate_inline_workflow_template_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.instantiate_inline_workflow_template), "__call__"
+        type(client.transport.instantiate_inline_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1228,7 +1353,7 @@ async def test_instantiate_inline_workflow_template_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.instantiate_inline_workflow_template), "__call__"
+        type(client.transport.instantiate_inline_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = operations_pb2.Operation(name="operations/op")
@@ -1283,7 +1408,7 @@ def test_update_workflow_template(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.update_workflow_template), "__call__"
+        type(client.transport.update_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate(
@@ -1299,6 +1424,7 @@ def test_update_workflow_template(
         assert args[0] == workflow_templates.UpdateWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, workflow_templates.WorkflowTemplate)

     assert response.id == "id_value"
@@ -1313,18 +1439,21 @@ def test_update_workflow_template_from_dict():

 @pytest.mark.asyncio
-async def test_update_workflow_template_async(transport: str = "grpc_asyncio"):
+async def test_update_workflow_template_async(
+    transport: str = "grpc_asyncio",
+    request_type=workflow_templates.UpdateWorkflowTemplateRequest,
+):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.UpdateWorkflowTemplateRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.update_workflow_template), "__call__"
+        type(client.transport.update_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
@@ -1339,7 +1468,7 @@ async def test_update_workflow_template_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.UpdateWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, workflow_templates.WorkflowTemplate)
@@ -1351,6 +1480,11 @@ async def test_update_workflow_template_async(transport: str = "grpc_asyncio"):
     assert response.version == 774


+@pytest.mark.asyncio
+async def test_update_workflow_template_async_from_dict():
+    await test_update_workflow_template_async(request_type=dict)
+
+
 def test_update_workflow_template_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1363,7 +1497,7 @@ def test_update_workflow_template_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.update_workflow_template), "__call__"
+        type(client.transport.update_workflow_template), "__call__"
     ) as call:
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -1394,7 +1528,7 @@ async def test_update_workflow_template_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.update_workflow_template), "__call__"
+        type(client.transport.update_workflow_template), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             workflow_templates.WorkflowTemplate()
@@ -1421,7 +1555,7 @@ def test_update_workflow_template_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.update_workflow_template), "__call__"
+        type(client.transport.update_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -1462,7 +1596,7 @@ async def test_update_workflow_template_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.update_workflow_template), "__call__"
+        type(client.transport.update_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.WorkflowTemplate()
@@ -1513,7 +1647,7 @@ def test_list_workflow_templates(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.ListWorkflowTemplatesResponse(
@@ -1529,6 +1663,7 @@ def test_list_workflow_templates(
         assert args[0] == workflow_templates.ListWorkflowTemplatesRequest()

     # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListWorkflowTemplatesPager)

     assert response.next_page_token == "next_page_token_value"
@@ -1539,18 +1674,21 @@ def test_list_workflow_templates_from_dict():

 @pytest.mark.asyncio
-async def test_list_workflow_templates_async(transport: str = "grpc_asyncio"):
+async def test_list_workflow_templates_async(
+    transport: str = "grpc_asyncio",
+    request_type=workflow_templates.ListWorkflowTemplatesRequest,
+):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.ListWorkflowTemplatesRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
@@ -1565,7 +1703,7 @@ async def test_list_workflow_templates_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.ListWorkflowTemplatesRequest()

     # Establish that the response is the type that we expect.
     assert isinstance(response, pagers.ListWorkflowTemplatesAsyncPager)
@@ -1573,6 +1711,11 @@ async def test_list_workflow_templates_async(transport: str = "grpc_asyncio"):
     assert response.next_page_token == "next_page_token_value"


+@pytest.mark.asyncio
+async def test_list_workflow_templates_async_from_dict():
+    await test_list_workflow_templates_async(request_type=dict)
+
+
 def test_list_workflow_templates_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1585,7 +1728,7 @@ def test_list_workflow_templates_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         call.return_value = workflow_templates.ListWorkflowTemplatesResponse()
@@ -1614,7 +1757,7 @@ async def test_list_workflow_templates_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
             workflow_templates.ListWorkflowTemplatesResponse()
@@ -1639,7 +1782,7 @@ def test_list_workflow_templates_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.ListWorkflowTemplatesResponse()
@@ -1677,7 +1820,7 @@ async def test_list_workflow_templates_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = workflow_templates.ListWorkflowTemplatesResponse()
@@ -1718,7 +1861,7 @@ def test_list_workflow_templates_pager():
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
@@ -1766,7 +1909,7 @@ def test_list_workflow_templates_pages():
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.list_workflow_templates), "__call__"
+        type(client.transport.list_workflow_templates), "__call__"
     ) as call:
         # Set the response to a series of pages.
         call.side_effect = (
@@ -1794,8 +1937,8 @@ def test_list_workflow_templates_pages():
             RuntimeError,
         )
         pages = list(client.list_workflow_templates(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token


 @pytest.mark.asyncio
@@ -1806,7 +1949,7 @@ async def test_list_workflow_templates_async_pager():
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_workflow_templates),
+        type(client.transport.list_workflow_templates),
         "__call__",
         new_callable=mock.AsyncMock,
     ) as call:
@@ -1855,7 +1998,7 @@ async def test_list_workflow_templates_async_pages():
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.list_workflow_templates),
+        type(client.transport.list_workflow_templates),
         "__call__",
         new_callable=mock.AsyncMock,
    ) as call:
@@ -1885,10 +2028,10 @@ async def test_list_workflow_templates_async_pages():
             RuntimeError,
         )
         pages = []
-        async for page in (await client.list_workflow_templates(request={})).pages:
-            pages.append(page)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
+        async for page_ in (await client.list_workflow_templates(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token

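Beyond the mechanical `page` -> `page_` rename, these pager tests verify both lazy item iteration and explicit page-level iteration. In application code that looks roughly like this, assuming ADC and a real project:

```python
from google.cloud.dataproc_v1beta2.services.workflow_template_service import (
    WorkflowTemplateServiceClient,
)

client = WorkflowTemplateServiceClient()  # assumes ADC is configured
pager = client.list_workflow_templates(
    parent="projects/my-project/regions/us-central1"
)

# Iterate items transparently across pages...
for template in pager:
    print(template.id)

# ...or walk the underlying pages explicitly, as the tests do.
for page_ in pager.pages:
    print(page_.raw_page.next_page_token)
```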

 def test_delete_workflow_template(
@@ -1905,7 +2048,7 @@ def test_delete_workflow_template(
     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.delete_workflow_template), "__call__"
+        type(client.transport.delete_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = None
@@ -1927,18 +2070,21 @@ def test_delete_workflow_template_from_dict():

 @pytest.mark.asyncio
-async def test_delete_workflow_template_async(transport: str = "grpc_asyncio"):
+async def test_delete_workflow_template_async(
+    transport: str = "grpc_asyncio",
+    request_type=workflow_templates.DeleteWorkflowTemplateRequest,
+):
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport=transport,
     )

     # Everything is optional in proto3 as far as the runtime is concerned,
     # and we are mocking out the actual API, so just send an empty request.
-    request = workflow_templates.DeleteWorkflowTemplateRequest()
+    request = request_type()

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.delete_workflow_template), "__call__"
+        type(client.transport.delete_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
@@ -1949,12 +2095,17 @@ async def test_delete_workflow_template_async(transport: str = "grpc_asyncio"):
         assert len(call.mock_calls)
         _, args, _ = call.mock_calls[0]

-        assert args[0] == request
+        assert args[0] == workflow_templates.DeleteWorkflowTemplateRequest()

     # Establish that the response is the type that we expect.
     assert response is None


+@pytest.mark.asyncio
+async def test_delete_workflow_template_async_from_dict():
+    await test_delete_workflow_template_async(request_type=dict)
+
+
 def test_delete_workflow_template_field_headers():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1967,7 +2118,7 @@ def test_delete_workflow_template_field_headers():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.delete_workflow_template), "__call__"
+        type(client.transport.delete_workflow_template), "__call__"
     ) as call:
         call.return_value = None
@@ -1996,7 +2147,7 @@ async def test_delete_workflow_template_field_headers_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.delete_workflow_template), "__call__"
+        type(client.transport.delete_workflow_template), "__call__"
     ) as call:
         call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
@@ -2019,7 +2170,7 @@ def test_delete_workflow_template_flattened():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._transport.delete_workflow_template), "__call__"
+        type(client.transport.delete_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = None
@@ -2057,7 +2208,7 @@ async def test_delete_workflow_template_flattened_async():

     # Mock the actual call within the gRPC stub, and fake the request.
     with mock.patch.object(
-        type(client._client._transport.delete_workflow_template), "__call__"
+        type(client.transport.delete_workflow_template), "__call__"
     ) as call:
         # Designate an appropriate return value for the call.
         call.return_value = None
@@ -2125,7 +2276,7 @@ def test_transport_instance():
         credentials=credentials.AnonymousCredentials(),
     )
     client = WorkflowTemplateServiceClient(transport=transport)
-    assert client._transport is transport
+    assert client.transport is transport


 def test_transport_get_channel():
@@ -2143,13 +2294,28 @@ def test_transport_get_channel():
     assert channel


+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.WorkflowTemplateServiceGrpcTransport,
+        transports.WorkflowTemplateServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_transport_adc(transport_class):
+    # Test default credentials are used if not provided.
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transport_class()
+        adc.assert_called_once()
+
+
 def test_transport_grpc_default():
     # A client should use the gRPC transport by default.
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(),
     )
     assert isinstance(
-        client._transport, transports.WorkflowTemplateServiceGrpcTransport,
+        client.transport, transports.WorkflowTemplateServiceGrpcTransport,
     )
@@ -2212,6 +2378,17 @@ def test_workflow_template_service_base_transport_with_credentials_file():
     )


+def test_workflow_template_service_base_transport_with_adc():
+    # Test the default credentials are used if credentials and credentials_file are None.
+    with mock.patch.object(auth, "default") as adc, mock.patch(
+        "google.cloud.dataproc_v1beta2.services.workflow_template_service.transports.WorkflowTemplateServiceTransport._prep_wrapped_messages"
+    ) as Transport:
+        Transport.return_value = None
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transport = transports.WorkflowTemplateServiceTransport()
+        adc.assert_called_once()
+
+
 def test_workflow_template_service_auth_adc():
     # If no credentials are provided, we should use ADC credentials.
     with mock.patch.object(auth, "default") as adc:
@@ -2244,7 +2421,7 @@ def test_workflow_template_service_host_no_port():
             api_endpoint="dataproc.googleapis.com"
         ),
     )
-    assert client._transport._host == "dataproc.googleapis.com:443"
+    assert client.transport._host == "dataproc.googleapis.com:443"


 def test_workflow_template_service_host_with_port():
@@ -2254,192 +2431,126 @@ def test_workflow_template_service_host_with_port():
             api_endpoint="dataproc.googleapis.com:8000"
         ),
     )
-    assert client._transport._host == "dataproc.googleapis.com:8000"
+    assert client.transport._host == "dataproc.googleapis.com:8000"


 def test_workflow_template_service_grpc_transport_channel():
     channel = grpc.insecure_channel("http://localhost/")

-    # Check that if channel is provided, mtls endpoint and client_cert_source
-    # won't be used.
-    callback = mock.MagicMock()
+    # Check that channel is used if provided.
     transport = transports.WorkflowTemplateServiceGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=callback,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
-    assert not callback.called
+    assert transport._ssl_channel_credentials == None


 def test_workflow_template_service_grpc_asyncio_transport_channel():
     channel = aio.insecure_channel("http://localhost/")

-    # Check that if channel is provided, mtls endpoint and client_cert_source
-    # won't be used.
-    callback = mock.MagicMock()
+    # Check that channel is used if provided.
     transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=callback,
+        host="squid.clam.whelk", channel=channel,
     )
     assert transport.grpc_channel == channel
     assert transport._host == "squid.clam.whelk:443"
-    assert not callback.called
-
-
-@mock.patch("grpc.ssl_channel_credentials", autospec=True)
-@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True)
-def test_workflow_template_service_grpc_transport_channel_mtls_with_client_cert_source(
-    grpc_create_channel, grpc_ssl_channel_cred
-):
-    # Check that if channel is None, but api_mtls_endpoint and client_cert_source
-    # are provided, then a mTLS channel will be created.
-    mock_cred = mock.Mock()
-
-    mock_ssl_cred = mock.Mock()
-    grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    transport = transports.WorkflowTemplateServiceGrpcTransport(
-        host="squid.clam.whelk",
-        credentials=mock_cred,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=client_cert_source_callback,
-    )
-    grpc_ssl_channel_cred.assert_called_once_with(
-        certificate_chain=b"cert bytes", private_key=b"key bytes"
-    )
-    grpc_create_channel.assert_called_once_with(
-        "mtls.squid.clam.whelk:443",
-        credentials=mock_cred,
-        credentials_file=None,
-        scopes=("https://www.googleapis.com/auth/cloud-platform",),
-        ssl_credentials=mock_ssl_cred,
-        quota_project_id=None,
-    )
-    assert transport.grpc_channel == mock_grpc_channel
-
-
-@mock.patch("grpc.ssl_channel_credentials", autospec=True)
-@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True)
-def test_workflow_template_service_grpc_asyncio_transport_channel_mtls_with_client_cert_source(
-    grpc_create_channel, grpc_ssl_channel_cred
-):
-    # Check that if channel is None, but api_mtls_endpoint and client_cert_source
-    # are provided, then a mTLS channel will be created.
-    mock_cred = mock.Mock()
-
-    mock_ssl_cred = mock.Mock()
-    grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        credentials=mock_cred,
-        api_mtls_endpoint="mtls.squid.clam.whelk",
-        client_cert_source=client_cert_source_callback,
-    )
-    grpc_ssl_channel_cred.assert_called_once_with(
-        certificate_chain=b"cert bytes", private_key=b"key bytes"
-    )
-    grpc_create_channel.assert_called_once_with(
-        "mtls.squid.clam.whelk:443",
-        credentials=mock_cred,
-        credentials_file=None,
-        scopes=("https://www.googleapis.com/auth/cloud-platform",),
-        ssl_credentials=mock_ssl_cred,
-        quota_project_id=None,
-    )
-    assert transport.grpc_channel == mock_grpc_channel
+    assert transport._ssl_channel_credentials == None


 @pytest.mark.parametrize(
-    "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"]
+    "transport_class",
+    [
+        transports.WorkflowTemplateServiceGrpcTransport,
+        transports.WorkflowTemplateServiceGrpcAsyncIOTransport,
+    ],
 )
-@mock.patch("google.api_core.grpc_helpers.create_channel", autospec=True)
-def test_workflow_template_service_grpc_transport_channel_mtls_with_adc(
-    grpc_create_channel, api_mtls_endpoint
+def test_workflow_template_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
 ):
-    # Check that if channel and client_cert_source are None, but api_mtls_endpoint
-    # is provided, then a mTLS channel will be created with SSL ADC.
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    # Mock google.auth.transport.grpc.SslCredentials class.
-    mock_ssl_cred = mock.Mock()
-    with mock.patch.multiple(
-        "google.auth.transport.grpc.SslCredentials",
-        __init__=mock.Mock(return_value=None),
-        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
-    ):
-        mock_cred = mock.Mock()
-        transport = transports.WorkflowTemplateServiceGrpcTransport(
-            host="squid.clam.whelk",
-            credentials=mock_cred,
-            api_mtls_endpoint=api_mtls_endpoint,
-            client_cert_source=None,
-        )
-        grpc_create_channel.assert_called_once_with(
-            "mtls.squid.clam.whelk:443",
-            credentials=mock_cred,
-            credentials_file=None,
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            ssl_credentials=mock_ssl_cred,
-            quota_project_id=None,
-        )
-        assert transport.grpc_channel == mock_grpc_channel
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred


 @pytest.mark.parametrize(
-    "api_mtls_endpoint", ["mtls.squid.clam.whelk", "mtls.squid.clam.whelk:443"]
+    "transport_class",
+    [
+        transports.WorkflowTemplateServiceGrpcTransport,
+        transports.WorkflowTemplateServiceGrpcAsyncIOTransport,
+    ],
 )
-@mock.patch("google.api_core.grpc_helpers_async.create_channel", autospec=True)
-def test_workflow_template_service_grpc_asyncio_transport_channel_mtls_with_adc(
-    grpc_create_channel, api_mtls_endpoint
-):
-    # Check that if channel and client_cert_source are None, but api_mtls_endpoint
-    # is provided, then a mTLS channel will be created with SSL ADC.
-    mock_grpc_channel = mock.Mock()
-    grpc_create_channel.return_value = mock_grpc_channel
-
-    # Mock google.auth.transport.grpc.SslCredentials class.
+def test_workflow_template_service_transport_channel_mtls_with_adc(transport_class):
     mock_ssl_cred = mock.Mock()
     with mock.patch.multiple(
         "google.auth.transport.grpc.SslCredentials",
         __init__=mock.Mock(return_value=None),
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
-        mock_cred = mock.Mock()
-        transport = transports.WorkflowTemplateServiceGrpcAsyncIOTransport(
-            host="squid.clam.whelk",
-            credentials=mock_cred,
-            api_mtls_endpoint=api_mtls_endpoint,
-            client_cert_source=None,
-        )
-        grpc_create_channel.assert_called_once_with(
-            "mtls.squid.clam.whelk:443",
-            credentials=mock_cred,
-            credentials_file=None,
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            ssl_credentials=mock_ssl_cred,
-            quota_project_id=None,
-        )
-        assert transport.grpc_channel == mock_grpc_channel
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel


 def test_workflow_template_service_grpc_lro_client():
     client = WorkflowTemplateServiceClient(
         credentials=credentials.AnonymousCredentials(), transport="grpc",
     )
-    transport = client._transport
+    transport = client.transport

     # Ensure that we have an api-core operations client.
     assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
@@ -2452,7 +2563,7 @@ def test_workflow_template_service_grpc_lro_async_client():
     client = WorkflowTemplateServiceAsyncClient(
         credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
     )
-    transport = client._client._transport
+    transport = client.transport

     # Ensure that we have an api-core operations client.
     assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
@@ -2486,3 +2597,125 @@ def test_parse_workflow_template_path():

     # Check that the path construction is reversible.
     actual = WorkflowTemplateServiceClient.parse_workflow_template_path(path)
     assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "cuttlefish"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = WorkflowTemplateServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "mussel",
+    }
+    path = WorkflowTemplateServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = WorkflowTemplateServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "winkle"
+
+    expected = "folders/{folder}".format(folder=folder,)
+    actual = WorkflowTemplateServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "nautilus",
+    }
+    path = WorkflowTemplateServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = WorkflowTemplateServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "scallop"
+
+    expected = "organizations/{organization}".format(organization=organization,)
+    actual = WorkflowTemplateServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "abalone",
+    }
+    path = WorkflowTemplateServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = WorkflowTemplateServiceClient.parse_common_organization_path(path)
+    assert expected == actual
+
+
+def test_common_project_path():
+    project = "squid"
+
+    expected = "projects/{project}".format(project=project,)
+    actual = WorkflowTemplateServiceClient.common_project_path(project)
+    assert expected == actual
+
+
+def test_parse_common_project_path():
+    expected = {
+        "project": "clam",
+    }
+    path = WorkflowTemplateServiceClient.common_project_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = WorkflowTemplateServiceClient.parse_common_project_path(path)
+    assert expected == actual
+
+
+def test_common_location_path():
+    project = "whelk"
+    location = "octopus"
+
+    expected = "projects/{project}/locations/{location}".format(
+        project=project, location=location,
+    )
+    actual = WorkflowTemplateServiceClient.common_location_path(project, location)
+    assert expected == actual
+
+
+def test_parse_common_location_path():
+    expected = {
+        "project": "oyster",
+        "location": "nudibranch",
+    }
+    path = WorkflowTemplateServiceClient.common_location_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = WorkflowTemplateServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_withDEFAULT_CLIENT_INFO():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(
+        transports.WorkflowTemplateServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
+        client = WorkflowTemplateServiceClient(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)

+    with mock.patch.object(
+        transports.WorkflowTemplateServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
+        transport_class = WorkflowTemplateServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
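
Reviewer note (not part of the patch): a minimal usage sketch of the two features exercised by the tests above. It assumes application default credentials are available in the environment; the project and location values are placeholders.

    from google.cloud.dataproc_v1beta2.services.workflow_template_service import (
        WorkflowTemplateServiceClient,
    )

    # The common resource path helpers added here are static helper methods on
    # the client class that format shared resource names; each has a parse_*
    # inverse, as the reversibility tests above verify.
    path = WorkflowTemplateServiceClient.common_location_path("my-project", "us-central1")
    assert path == "projects/my-project/locations/us-central1"
    assert WorkflowTemplateServiceClient.parse_common_location_path(path) == {
        "project": "my-project",
        "location": "us-central1",
    }

    # The transport is now exposed as a public property: client.transport
    # replaces the private client._transport attribute used before this change.
    client = WorkflowTemplateServiceClient()
    print(client.transport._host)  # "dataproc.googleapis.com:443" by default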