From 1fbdeceee5eba78233b913885be2cbffc3ca7904 Mon Sep 17 00:00:00 2001 From: Yoshi Automation Bot Date: Wed, 7 Apr 2021 13:14:02 -0700 Subject: [PATCH] fix: use correct retry deadlines (#63) This PR was generated using Autosynth. :rainbow: Synth log will be available here: https://source.cloud.google.com/results/invocations/f041d1b5-526a-41af-bdb4-4cc8c013c941/targets - [ ] To automatically regenerate this PR, check this box. (May take up to 24 hours.) PiperOrigin-RevId: 364411656 Source-Link: https://github.com/googleapis/googleapis/commit/149a3a84c29c9b8189576c7442ccb6dcf6a8f95b PiperOrigin-RevId: 361662015 Source-Link: https://github.com/googleapis/googleapis/commit/28a591963253d52ce3a25a918cafbdd9928de8cf PiperOrigin-RevId: 359562873 Source-Link: https://github.com/googleapis/googleapis/commit/07932bb995e7dc91b43620ea8402c6668c7d102c PiperOrigin-RevId: 355923884 Source-Link: https://github.com/googleapis/googleapis/commit/5e3dacee19405529b841b53797df799c2383536c PiperOrigin-RevId: 354996675 Source-Link: https://github.com/googleapis/googleapis/commit/20712b8fe95001b312f62c6c5f33e3e3ec92cfaf --- .../services/iam_credentials/async_client.py | 36 ++- .../services/iam_credentials/client.py | 42 ++- .../iam_credentials/transports/base.py | 22 +- .../iam_credentials/transports/grpc.py | 112 ++++---- .../transports/grpc_asyncio.py | 120 ++++---- .../iam_credentials_v1/types/__init__.py | 8 +- synth.metadata | 4 +- tests/unit/gapic/credentials_v1/__init__.py | 15 + .../credentials_v1/test_iam_credentials.py | 260 ++++++++++++------ 9 files changed, 368 insertions(+), 251 deletions(-) diff --git a/google/cloud/iam_credentials_v1/services/iam_credentials/async_client.py b/google/cloud/iam_credentials_v1/services/iam_credentials/async_client.py index 557b525..443628e 100644 --- a/google/cloud/iam_credentials_v1/services/iam_credentials/async_client.py +++ b/google/cloud/iam_credentials_v1/services/iam_credentials/async_client.py @@ -89,8 +89,36 @@ class IAMCredentialsAsyncClient: IAMCredentialsClient.parse_common_location_path ) - from_service_account_info = IAMCredentialsClient.from_service_account_info - from_service_account_file = IAMCredentialsClient.from_service_account_file + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IAMCredentialsAsyncClient: The constructed client. + """ + return IAMCredentialsClient.from_service_account_info.__func__(IAMCredentialsAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + IAMCredentialsAsyncClient: The constructed client. + """ + return IAMCredentialsClient.from_service_account_file.__func__(IAMCredentialsAsyncClient, filename, *args, **kwargs) # type: ignore + from_service_account_json = from_service_account_file @property @@ -270,6 +298,7 @@ async def generate_access_token( predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -397,6 +426,7 @@ async def generate_id_token( predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -510,6 +540,7 @@ async def sign_blob( predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, @@ -626,6 +657,7 @@ async def sign_jwt( predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=DEFAULT_CLIENT_INFO, diff --git a/google/cloud/iam_credentials_v1/services/iam_credentials/client.py b/google/cloud/iam_credentials_v1/services/iam_credentials/client.py index 5614db4..1261fb9 100644 --- a/google/cloud/iam_credentials_v1/services/iam_credentials/client.py +++ b/google/cloud/iam_credentials_v1/services/iam_credentials/client.py @@ -296,21 +296,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -353,7 +349,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -460,14 +456,13 @@ def generate_access_token( if name is not None: request.name = name + if delegates is not None: + request.delegates = delegates + if scope is not None: + request.scope = scope if lifetime is not None: request.lifetime = lifetime - if delegates: - request.delegates.extend(delegates) - if scope: - request.scope.extend(scope) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.generate_access_token] @@ -580,14 +575,13 @@ def generate_id_token( if name is not None: request.name = name + if delegates is not None: + request.delegates = delegates if audience is not None: request.audience = audience if include_email is not None: request.include_email = include_email - if delegates: - request.delegates.extend(delegates) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.generate_id_token] @@ -688,12 +682,11 @@ def sign_blob( if name is not None: request.name = name + if delegates is not None: + request.delegates = delegates if payload is not None: request.payload = payload - if delegates: - request.delegates.extend(delegates) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.sign_blob] @@ -797,12 +790,11 @@ def sign_jwt( if name is not None: request.name = name + if delegates is not None: + request.delegates = delegates if payload is not None: request.payload = payload - if delegates: - request.delegates.extend(delegates) - # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.sign_jwt] diff --git a/google/cloud/iam_credentials_v1/services/iam_credentials/transports/base.py b/google/cloud/iam_credentials_v1/services/iam_credentials/transports/base.py index 9e5186e..7b87338 100644 --- a/google/cloud/iam_credentials_v1/services/iam_credentials/transports/base.py +++ b/google/cloud/iam_credentials_v1/services/iam_credentials/transports/base.py @@ -67,10 +67,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. @@ -78,6 +78,9 @@ def __init__( host += ":443" self._host = host + # Save the scopes. + self._scopes = scopes or self.AUTH_SCOPES + # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: @@ -87,20 +90,17 @@ def __init__( if credentials_file is not None: credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id + credentials_file, scopes=self._scopes, quota_project_id=quota_project_id ) elif credentials is None: credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id + scopes=self._scopes, quota_project_id=quota_project_id ) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - def _prep_wrapped_messages(self, client_info): # Precompute the wrapped methods. self._wrapped_methods = { @@ -113,6 +113,7 @@ def _prep_wrapped_messages(self, client_info): predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, @@ -126,6 +127,7 @@ def _prep_wrapped_messages(self, client_info): predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, @@ -139,6 +141,7 @@ def _prep_wrapped_messages(self, client_info): predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, @@ -152,6 +155,7 @@ def _prep_wrapped_messages(self, client_info): predicate=retries.if_exception_type( exceptions.DeadlineExceeded, exceptions.ServiceUnavailable, ), + deadline=60.0, ), default_timeout=60.0, client_info=client_info, diff --git a/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc.py b/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc.py index 8f74d2e..a60ed7f 100644 --- a/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc.py +++ b/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc.py @@ -66,6 +66,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -96,6 +97,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -110,72 +115,60 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials - else: - host = host if ":" in host else host + ":443" + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -183,17 +176,8 @@ def __init__( ], ) - self._stubs = {} # type: Dict[str, Callable] - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @classmethod def create_channel( @@ -207,7 +191,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If diff --git a/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc_asyncio.py b/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc_asyncio.py index 9cedf8a..026460d 100644 --- a/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc_asyncio.py +++ b/google/cloud/iam_credentials_v1/services/iam_credentials/transports/grpc_asyncio.py @@ -70,7 +70,7 @@ def create_channel( ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: - address (Optional[str]): The host for the channel to use. + host (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -110,6 +110,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -141,12 +142,16 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -155,72 +160,60 @@ def __init__( google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._grpc_channel = None self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. + # Ignore credentials if a channel was passed. credentials = False - # If a channel was explicitly provided, set it. self._grpc_channel = channel self._ssl_channel_credentials = None - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - self._ssl_channel_credentials = ssl_credentials - else: - host = host if ":" in host else host + ":443" + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + ) - # create a new channel. The provided one is ignored. + if not self._grpc_channel: self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, + self._host, + credentials=self._credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, quota_project_id=quota_project_id, options=[ ("grpc.max_send_message_length", -1), @@ -228,17 +221,8 @@ def __init__( ], ) - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) @property def grpc_channel(self) -> aio.Channel: diff --git a/google/cloud/iam_credentials_v1/types/__init__.py b/google/cloud/iam_credentials_v1/types/__init__.py index 936ab81..dbdc3bb 100644 --- a/google/cloud/iam_credentials_v1/types/__init__.py +++ b/google/cloud/iam_credentials_v1/types/__init__.py @@ -18,21 +18,21 @@ from .common import ( GenerateAccessTokenRequest, GenerateAccessTokenResponse, + GenerateIdTokenRequest, + GenerateIdTokenResponse, SignBlobRequest, SignBlobResponse, SignJwtRequest, SignJwtResponse, - GenerateIdTokenRequest, - GenerateIdTokenResponse, ) __all__ = ( "GenerateAccessTokenRequest", "GenerateAccessTokenResponse", + "GenerateIdTokenRequest", + "GenerateIdTokenResponse", "SignBlobRequest", "SignBlobResponse", "SignJwtRequest", "SignJwtResponse", - "GenerateIdTokenRequest", - "GenerateIdTokenResponse", ) diff --git a/synth.metadata b/synth.metadata index 938a7b3..11cdd32 100644 --- a/synth.metadata +++ b/synth.metadata @@ -11,8 +11,8 @@ "git": { "name": "googleapis", "remote": "https://github.com/googleapis/googleapis.git", - "sha": "520682435235d9c503983a360a2090025aa47cd1", - "internalRef": "350246057" + "sha": "149a3a84c29c9b8189576c7442ccb6dcf6a8f95b", + "internalRef": "364411656" } }, { diff --git a/tests/unit/gapic/credentials_v1/__init__.py b/tests/unit/gapic/credentials_v1/__init__.py index 8b13789..42ffdf2 100644 --- a/tests/unit/gapic/credentials_v1/__init__.py +++ b/tests/unit/gapic/credentials_v1/__init__.py @@ -1 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/credentials_v1/test_iam_credentials.py b/tests/unit/gapic/credentials_v1/test_iam_credentials.py index ede64da..e6fba4a 100644 --- a/tests/unit/gapic/credentials_v1/test_iam_credentials.py +++ b/tests/unit/gapic/credentials_v1/test_iam_credentials.py @@ -89,15 +89,19 @@ def test__get_default_mtls_endpoint(): ) -def test_iam_credentials_client_from_service_account_info(): +@pytest.mark.parametrize( + "client_class", [IAMCredentialsClient, IAMCredentialsAsyncClient,] +) +def test_iam_credentials_client_from_service_account_info(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( service_account.Credentials, "from_service_account_info" ) as factory: factory.return_value = creds info = {"valid": True} - client = IAMCredentialsClient.from_service_account_info(info) + client = client_class.from_service_account_info(info) assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "iamcredentials.googleapis.com:443" @@ -113,9 +117,11 @@ def test_iam_credentials_client_from_service_account_file(client_class): factory.return_value = creds client = client_class.from_service_account_file("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) client = client_class.from_service_account_json("dummy/file/path.json") assert client.transport._credentials == creds + assert isinstance(client, client_class) assert client.transport._host == "iamcredentials.googleapis.com:443" @@ -176,7 +182,7 @@ def test_iam_credentials_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -192,7 +198,7 @@ def test_iam_credentials_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +214,7 @@ def test_iam_credentials_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -236,7 +242,7 @@ def test_iam_credentials_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -287,29 +293,25 @@ def test_iam_credentials_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -318,66 +320,53 @@ def test_iam_credentials_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -403,7 +392,7 @@ def test_iam_credentials_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -433,7 +422,7 @@ def test_iam_credentials_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -452,7 +441,7 @@ def test_iam_credentials_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -497,6 +486,24 @@ def test_generate_access_token_from_dict(): test_generate_access_token(request_type=dict) +def test_generate_access_token_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IAMCredentialsClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_access_token), "__call__" + ) as call: + client.generate_access_token() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == common.GenerateAccessTokenRequest() + + @pytest.mark.asyncio async def test_generate_access_token_async( transport: str = "grpc_asyncio", request_type=common.GenerateAccessTokenRequest @@ -734,6 +741,24 @@ def test_generate_id_token_from_dict(): test_generate_id_token(request_type=dict) +def test_generate_id_token_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IAMCredentialsClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_id_token), "__call__" + ) as call: + client.generate_id_token() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == common.GenerateIdTokenRequest() + + @pytest.mark.asyncio async def test_generate_id_token_async( transport: str = "grpc_asyncio", request_type=common.GenerateIdTokenRequest @@ -967,6 +992,22 @@ def test_sign_blob_from_dict(): test_sign_blob(request_type=dict) +def test_sign_blob_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IAMCredentialsClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.sign_blob), "__call__") as call: + client.sign_blob() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == common.SignBlobRequest() + + @pytest.mark.asyncio async def test_sign_blob_async( transport: str = "grpc_asyncio", request_type=common.SignBlobRequest @@ -1182,6 +1223,22 @@ def test_sign_jwt_from_dict(): test_sign_jwt(request_type=dict) +def test_sign_jwt_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = IAMCredentialsClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.sign_jwt), "__call__") as call: + client.sign_jwt() + call.assert_called() + _, args, _ = call.mock_calls[0] + + assert args[0] == common.SignJwtRequest() + + @pytest.mark.asyncio async def test_sign_jwt_async( transport: str = "grpc_asyncio", request_type=common.SignJwtRequest @@ -1522,6 +1579,51 @@ def test_iam_credentials_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class", + [ + transports.IAMCredentialsGrpcTransport, + transports.IAMCredentialsGrpcAsyncIOTransport, + ], +) +def test_iam_credentials_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_iam_credentials_host_no_port(): client = IAMCredentialsClient( credentials=credentials.AnonymousCredentials(), @@ -1566,6 +1668,8 @@ def test_iam_credentials_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -1618,6 +1722,8 @@ def test_iam_credentials_transport_channel_mtls_with_client_cert_source( assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [