diff --git a/docs/index.rst b/docs/index.rst index 4287c3db3..17169109a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,6 +20,8 @@ also provides integration with several HTTP libraries. - Support for Google :mod:`Impersonated Credentials `. - Support for :mod:`Google Compute Engine credentials `. - Support for :mod:`Google App Engine standard credentials `. +- Support for :mod:`Identity Pool credentials `. +- Support for :mod:`AWS credentials `. - Support for various transports, including :mod:`Requests `, :mod:`urllib3 `, and diff --git a/docs/reference/google.auth.aws.rst b/docs/reference/google.auth.aws.rst new file mode 100644 index 000000000..9c3966bba --- /dev/null +++ b/docs/reference/google.auth.aws.rst @@ -0,0 +1,7 @@ +google.auth.aws module +====================== + +.. automodule:: google.auth.aws + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.external_account.rst b/docs/reference/google.auth.external_account.rst new file mode 100644 index 000000000..0681eaa27 --- /dev/null +++ b/docs/reference/google.auth.external_account.rst @@ -0,0 +1,7 @@ +google.auth.external\_account module +==================================== + +.. automodule:: google.auth.external_account + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.identity_pool.rst b/docs/reference/google.auth.identity_pool.rst new file mode 100644 index 000000000..48d990223 --- /dev/null +++ b/docs/reference/google.auth.identity_pool.rst @@ -0,0 +1,7 @@ +google.auth.identity\_pool module +================================= + +.. automodule:: google.auth.identity_pool + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.rst b/docs/reference/google.auth.rst index 3acf7dfb8..e21eaf9e3 100644 --- a/docs/reference/google.auth.rst +++ b/docs/reference/google.auth.rst @@ -23,11 +23,14 @@ Submodules :maxdepth: 4 google.auth.app_engine + google.auth.aws google.auth.credentials google.auth._credentials_async google.auth.environment_vars google.auth.exceptions + google.auth.external_account google.auth.iam + google.auth.identity_pool google.auth.impersonated_credentials google.auth.jwt google.auth.jwt_async diff --git a/docs/reference/google.oauth2.rst b/docs/reference/google.oauth2.rst index 6f3ba50c2..2a8a7a588 100644 --- a/docs/reference/google.oauth2.rst +++ b/docs/reference/google.oauth2.rst @@ -17,3 +17,5 @@ Submodules google.oauth2.id_token google.oauth2.service_account google.oauth2._service_account_async + google.oauth2.sts + google.oauth2.utils diff --git a/docs/reference/google.oauth2.sts.rst b/docs/reference/google.oauth2.sts.rst new file mode 100644 index 000000000..49d99dfe6 --- /dev/null +++ b/docs/reference/google.oauth2.sts.rst @@ -0,0 +1,7 @@ +google.oauth2.sts module +======================== + +.. automodule:: google.oauth2.sts + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.oauth2.utils.rst b/docs/reference/google.oauth2.utils.rst new file mode 100644 index 000000000..5b039eac8 --- /dev/null +++ b/docs/reference/google.oauth2.utils.rst @@ -0,0 +1,7 @@ +google.oauth2.utils module +========================== + +.. automodule:: google.oauth2.utils + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/user-guide.rst b/docs/user-guide.rst index 08e7167df..7332bd48e 100644 --- a/docs/user-guide.rst +++ b/docs/user-guide.rst @@ -7,8 +7,8 @@ Credentials and account types ----------------------------- :class:`~credentials.Credentials` are the means of identifying an application or -user to a service or API. Credentials can be obtained with two different types -of accounts: *service accounts* and *user accounts*. +user to a service or API. Credentials can be obtained with three different types +of accounts: *service accounts*, *user accounts* and *external accounts*. Credentials from service accounts identify a particular application. These types of credentials are used in server-to-server use cases, such as accessing a @@ -21,6 +21,11 @@ a user's documents in Google Drive. This library provides no support for obtaining user credentials, but does provide limited support for using user credentials. +Credentials from external accounts (workload identity federation) are used to +identify a particular application from an on-prem or non-Google Cloud platform +including Amazon Web Services (AWS), Microsoft Azure or any identity provider +that supports OpenID Connect (OIDC). + Obtaining credentials --------------------- @@ -44,6 +49,13 @@ If your application requires specific scopes:: credentials, project = google.auth.default( scopes=['https://www.googleapis.com/auth/cloud-platform']) +Application Default Credentials also support workload identity federation to +access Google Cloud resources from non-Google Cloud platforms including Amazon +Web Services (AWS), Microsoft Azure or any identity provider that supports +OpenID Connect (OIDC). Workload identity federation is recommended for +non-Google Cloud environments as it avoids the need to download, manage and +store service account private keys locally. + .. _Google Application Default Credentials: https://developers.google.com/identity/protocols/ application-default-credentials @@ -219,6 +231,163 @@ You can also use :class:`google_auth_oauthlib.flow.Flow` to perform the OAuth .. _requests-oauthlib: https://requests-oauthlib.readthedocs.io/en/latest/ +External credentials (Workload identity federation) ++++++++++++++++++++++++++++++++++++++++++++++++++++ + +Using workload identity federation, your application can access Google Cloud +resources from Amazon Web Services (AWS), Microsoft Azure or any identity +provider that supports OpenID Connect (OIDC). + +Traditionally, applications running outside Google Cloud have used service +account keys to access Google Cloud resources. Using identity federation, +you can allow your workload to impersonate a service account. +This lets you access Google Cloud resources directly, eliminating the +maintenance and security burden associated with service account keys. + +Accessing resources from AWS +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In order to access Google Cloud resources from Amazon Web Services (AWS), the +following requirements are needed: + +- A workload identity pool needs to be created. +- AWS needs to be added as an identity provider in the workload identity pool + (The Google organization policy needs to allow federation from AWS). +- Permission to impersonate a service account needs to be granted to the + external identity. +- A credential configuration file needs to be generated. Unlike service account + credential files, the generated credential configuration file will only + contain non-sensitive metadata to instruct the library on how to retrieve + external subject tokens and exchange them for service account access tokens. + +Follow the detailed instructions on how to +`Configure Workload Identity Federation from AWS`_. + +.. _Configure Workload Identity Federation from AWS: + https://cloud.google.com/iam/docs/access-resources-aws + +Accessing resources from Microsoft Azure +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In order to access Google Cloud resources from Microsoft Azure, the following +requirements are needed: + +- A workload identity pool needs to be created. +- Azure needs to be added as an identity provider in the workload identity pool + (The Google organization policy needs to allow federation from Azure). +- The Azure tenant needs to be configured for identity federation. +- Permission to impersonate a service account needs to be granted to the + external identity. +- A credential configuration file needs to be generated. Unlike service account + credential files, the generated credential configuration file will only + contain non-sensitive metadata to instruct the library on how to retrieve + external subject tokens and exchange them for service account access tokens. + +Follow the detailed instructions on how to +`Configure Workload Identity Federation from Microsoft Azure`_. + +.. _Configure Workload Identity Federation from Microsoft Azure: + https://cloud.google.com/iam/docs/access-resources-azure + +Accessing resources from an OIDC identity provider +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In order to access Google Cloud resources from an identity provider that +supports `OpenID Connect (OIDC)`_, the following requirements are needed: + +- A workload identity pool needs to be created. +- An OIDC identity provider needs to be added in the workload identity pool + (The Google organization policy needs to allow federation from the identity + provider). +- Permission to impersonate a service account needs to be granted to the + external identity. +- A credential configuration file needs to be generated. Unlike service account + credential files, the generated credential configuration file will only + contain non-sensitive metadata to instruct the library on how to retrieve + external subject tokens and exchange them for service account access tokens. + +For OIDC providers, the Auth library can retrieve OIDC tokens either from a +local file location (file-sourced credentials) or from a local server +(URL-sourced credentials). + +- For file-sourced credentials, a background process needs to be continuously + refreshing the file location with a new OIDC token prior to expiration. + For tokens with one hour lifetimes, the token needs to be updated in the file + every hour. The token can be stored directly as plain text or in JSON format. +- For URL-sourced credentials, a local server needs to host a GET endpoint to + return the OIDC token. The response can be in plain text or JSON. + Additional required request headers can also be specified. + +Follow the detailed instructions on how to +`Configure Workload Identity Federation from an OIDC identity provider`_. + +.. _OpenID Connect (OIDC): + https://openid.net/connect/ +.. _Configure Workload Identity Federation from an OIDC identity provider: + https://cloud.google.com/iam/docs/access-resources-oidc + +Using External Identities +~~~~~~~~~~~~~~~~~~~~~~~~~ + +External identities (AWS, Azure and OIDC identity providers) can be used with +Application Default Credentials. +In order to use external identities with Application Default Credentials, you +need to generate the JSON credentials configuration file for your external +identity. +Once generated, store the path to this file in the +``GOOGLE_APPLICATION_CREDENTIALS`` environment variable. + +.. code-block:: bash + + $ export GOOGLE_APPLICATION_CREDENTIALS=/path/to/config.json + +The library can now automatically choose the right type of client and initialize +credentials from the context provided in the configuration file:: + + import google.auth + + credentials, project = google.auth.default() + +When using external identities with Application Default Credentials, +the ``roles/browser`` role needs to be granted to the service account. +The ``Cloud Resource Manager API`` should also be enabled on the project. +This is needed since :func:`default` will try to auto-discover the project ID +from the current environment using the impersonated credential. +Otherwise, the project ID will resolve to ``None``. You can override the project +detection by setting the ``GOOGLE_CLOUD_PROJECT`` environment variable. + +You can also explicitly initialize external account clients using the generated +configuration file. + +For Azure and OIDC providers, use :meth:`identity_pool.Credentials.from_info +` or +:meth:`identity_pool.Credentials.from_file +`:: + + import json + + from google.auth import identity_pool + + json_config_info = json.loads(function_to_get_json_config()) + credentials = identity_pool.Credentials.from_info(json_config_info) + scoped_credentials = credentials.with_scopes( + ['https://www.googleapis.com/auth/cloud-platform']) + +For AWS providers, use :meth:`aws.Credentials.from_info +` or +:meth:`aws.Credentials.from_file +`:: + + import json + + from google.auth import aws + + json_config_info = json.loads(function_to_get_json_config()) + credentials = aws.Credentials.from_info(json_config_info) + scoped_credentials = credentials.with_scopes( + ['https://www.googleapis.com/auth/cloud-platform']) + + Impersonated credentials ++++++++++++++++++++++++ diff --git a/google/auth/_default.py b/google/auth/_default.py index 3b8c281e7..836c33915 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -34,7 +34,8 @@ # Valid types accepted for file-based credentials. _AUTHORIZED_USER_TYPE = "authorized_user" _SERVICE_ACCOUNT_TYPE = "service_account" -_VALID_TYPES = (_AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE) +_EXTERNAL_ACCOUNT_TYPE = "external_account" +_VALID_TYPES = (_AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE, _EXTERNAL_ACCOUNT_TYPE) # Help message when no credentials can be found. _HELP_MESSAGE = """\ @@ -70,12 +71,12 @@ def _warn_about_problematic_credentials(credentials): def load_credentials_from_file( - filename, scopes=None, default_scopes=None, quota_project_id=None + filename, scopes=None, default_scopes=None, quota_project_id=None, request=None ): """Loads Google credentials from a file. - The credentials file must be a service account key or stored authorized - user credentials. + The credentials file must be a service account key, stored authorized + user credentials or external account credentials. Args: filename (str): The full path to the credentials file. @@ -85,12 +86,18 @@ def load_credentials_from_file( default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. quota_project_id (Optional[str]): The project ID used for - quota and billing. + quota and billing. + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to determine the associated project ID + for a workload identity pool resource (external account credentials). + If not specified, then it will use a + google.auth.transport.requests.Request client to make requests. Returns: Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded credentials and the project ID. Authorized user credentials do not - have the project ID information. + have the project ID information. External account credentials project + IDs may not always be determined. Raises: google.auth.exceptions.DefaultCredentialsError: if the file is in the @@ -146,6 +153,18 @@ def load_credentials_from_file( credentials = credentials.with_quota_project(quota_project_id) return credentials, info.get("project_id") + elif credential_type == _EXTERNAL_ACCOUNT_TYPE: + credentials, project_id = _get_external_account_credentials( + info, + filename, + scopes=scopes, + default_scopes=default_scopes, + request=request, + ) + if quota_project_id: + credentials = credentials.with_quota_project(quota_project_id) + return credentials, project_id + else: raise exceptions.DefaultCredentialsError( "The file {file} does not have a valid type. " @@ -176,9 +195,28 @@ def _get_gcloud_sdk_credentials(): return credentials, project_id -def _get_explicit_environ_credentials(): +def _get_explicit_environ_credentials(request=None, scopes=None, default_scopes=None): """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment - variable.""" + variable. + + Args: + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to determine the associated project ID + for a workload identity pool resource (external account credentials). + If not specified, then it will use a + google.auth.transport.requests.Request client to make requests. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + + Returns: + Tuple[Optional[google.auth.credentials.Credentials], Optional[str]]: Loaded + credentials and the project ID. Authorized user credentials do not + have the project ID information. External account credentials project + IDs may not always be determined. + """ explicit_file = os.environ.get(environment_vars.CREDENTIALS) _LOGGER.debug( @@ -187,7 +225,11 @@ def _get_explicit_environ_credentials(): if explicit_file is not None: credentials, project_id = load_credentials_from_file( - os.environ[environment_vars.CREDENTIALS] + os.environ[environment_vars.CREDENTIALS], + scopes=scopes, + default_scopes=default_scopes, + quota_project_id=None, + request=request, ) return credentials, project_id @@ -252,6 +294,65 @@ def _get_gce_credentials(request=None): return None, None +def _get_external_account_credentials( + info, filename, scopes=None, default_scopes=None, request=None +): + """Loads external account Credentials from the parsed external account info. + + The credentials information must correspond to a supported external account + credentials. + + Args: + info (Mapping[str, str]): The external account info in Google format. + filename (str): The full path to the credentials file. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to determine the associated project ID + for a workload identity pool resource (external account credentials). + If not specified, then it will use a + google.auth.transport.requests.Request client to make requests. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. External account credentials project + IDs may not always be determined. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the info dictionary + is in the wrong format or is missing required information. + """ + # There are currently 2 types of external_account credentials. + try: + # Check if configuration corresponds to an AWS credentials. + from google.auth import aws + + credentials = aws.Credentials.from_info( + info, scopes=scopes, default_scopes=default_scopes + ) + except ValueError: + try: + # Check if configuration corresponds to an Identity Pool credentials. + from google.auth import identity_pool + + credentials = identity_pool.Credentials.from_info( + info, scopes=scopes, default_scopes=default_scopes + ) + except ValueError: + # If the configuration is invalid or does not correspond to any + # supported external_account credentials, raise an error. + raise exceptions.DefaultCredentialsError( + "Failed to load external account credentials from {}".format(filename) + ) + if request is None: + request = google.auth.transport.requests.Request() + + return credentials, credentials.get_project_id(request=request) + + def default(scopes=None, request=None, quota_project_id=None, default_scopes=None): """Gets the default credentials for the current environment. @@ -265,6 +366,15 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non loaded and returned. The project ID returned is the project ID defined in the service account file if available (some older files do not contain project ID information). + + If the environment variable is set to the path of a valid external + account JSON configuration file (workload identity federation), then the + configuration file is used to determine and retrieve the external + credentials from the current environment (AWS, Azure, etc). + These will then be exchanged for Google access tokens via the Google STS + endpoint. + The project ID returned in this case is the one corresponding to the + underlying workload identity pool resource if determinable. 2. If the `Google Cloud SDK`_ is installed and has application default credentials set they are loaded and returned. @@ -310,11 +420,15 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non scopes (Sequence[str]): The list of scopes for the credentials. If specified, the credentials will automatically be scoped if necessary. - request (google.auth.transport.Request): An object used to make - HTTP requests. This is used to detect whether the application - is running on Compute Engine. If not specified, then it will - use the standard library http client to make requests. - quota_project_id (Optional[str]): The project ID used for + request (Optional[google.auth.transport.Request]): An object used to make + HTTP requests. This is used to either detect whether the application + is running on Compute Engine or to determine the associated project + ID for a workload identity pool resource (external account + credentials). If not specified, then it will either use the standard + library http client to make requests for Compute Engine credentials + or a google.auth.transport.requests.Request client for external + account credentials. + quota_project_id (Optional[str]): The project ID used for quota and billing. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. @@ -336,7 +450,9 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non ) checkers = ( - _get_explicit_environ_credentials, + lambda: _get_explicit_environ_credentials( + request=request, scopes=scopes, default_scopes=default_scopes + ), _get_gcloud_sdk_credentials, _get_gae_credentials, lambda: _get_gce_credentials(request), diff --git a/google/auth/aws.py b/google/auth/aws.py new file mode 100644 index 000000000..b362dd315 --- /dev/null +++ b/google/auth/aws.py @@ -0,0 +1,714 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AWS Credentials and AWS Signature V4 Request Signer. + +This module provides credentials to access Google Cloud resources from Amazon +Web Services (AWS) workloads. These credentials are recommended over the +use of service account credentials in AWS as they do not involve the management +of long-live service account private keys. + +AWS Credentials are initialized using external_account arguments which are +typically loaded from the external credentials JSON file. +Unlike other Credentials that can be initialized with a list of explicit +arguments, secrets or credentials, external account clients use the +environment and hints/guidelines provided by the external_account JSON +file to retrieve credentials and exchange them for Google access tokens. + +This module also provides a basic implementation of the +`AWS Signature Version 4`_ request signing algorithm. + +AWS Credentials use serialized signed requests to the +`AWS STS GetCallerIdentity`_ API that can be exchanged for Google access tokens +via the GCP STS endpoint. + +.. _AWS Signature Version 4: https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html +.. _AWS STS GetCallerIdentity: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html +""" + +import hashlib +import hmac +import io +import json +import os +import re + +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import external_account + +# AWS Signature Version 4 signing algorithm identifier. +_AWS_ALGORITHM = "AWS4-HMAC-SHA256" +# The termination string for the AWS credential scope value as defined in +# https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html +_AWS_REQUEST_TYPE = "aws4_request" +# The AWS authorization header name for the security session token if available. +_AWS_SECURITY_TOKEN_HEADER = "x-amz-security-token" +# The AWS authorization header name for the auto-generated date. +_AWS_DATE_HEADER = "x-amz-date" + + +class RequestSigner(object): + """Implements an AWS request signer based on the AWS Signature Version 4 signing + process. + https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html + """ + + def __init__(self, region_name): + """Instantiates an AWS request signer used to compute authenticated signed + requests to AWS APIs based on the AWS Signature Version 4 signing process. + + Args: + region_name (str): The AWS region to use. + """ + + self._region_name = region_name + + def get_request_options( + self, + aws_security_credentials, + url, + method, + request_payload="", + additional_headers={}, + ): + """Generates the signed request for the provided HTTP request for calling + an AWS API. This follows the steps described at: + https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html + + Args: + aws_security_credentials (Mapping[str, str]): A dictionary containing + the AWS security credentials. + url (str): The AWS service URL containing the canonical URI and + query string. + method (str): The HTTP method used to call this API. + request_payload (Optional[str]): The optional request payload if + available. + additional_headers (Optional[Mapping[str, str]]): The optional + additional headers needed for the requested AWS API. + + Returns: + Mapping[str, str]: The AWS signed request dictionary object. + """ + # Get AWS credentials. + access_key = aws_security_credentials.get("access_key_id") + secret_key = aws_security_credentials.get("secret_access_key") + security_token = aws_security_credentials.get("security_token") + + additional_headers = additional_headers or {} + + uri = urllib.parse.urlparse(url) + # Validate provided URL. + if not uri.hostname or uri.scheme != "https": + raise ValueError("Invalid AWS service URL") + + header_map = _generate_authentication_header_map( + host=uri.hostname, + canonical_uri=os.path.normpath(uri.path or "/"), + canonical_querystring=_get_canonical_querystring(uri.query), + method=method, + region=self._region_name, + access_key=access_key, + secret_key=secret_key, + security_token=security_token, + request_payload=request_payload, + additional_headers=additional_headers, + ) + headers = { + "Authorization": header_map.get("authorization_header"), + "host": uri.hostname, + } + # Add x-amz-date if available. + if "amz_date" in header_map: + headers[_AWS_DATE_HEADER] = header_map.get("amz_date") + # Append additional optional headers, eg. X-Amz-Target, Content-Type, etc. + for key in additional_headers: + headers[key] = additional_headers[key] + + # Add session token if available. + if security_token is not None: + headers[_AWS_SECURITY_TOKEN_HEADER] = security_token + + signed_request = {"url": url, "method": method, "headers": headers} + if request_payload: + signed_request["data"] = request_payload + return signed_request + + +def _get_canonical_querystring(query): + """Generates the canonical query string given a raw query string. + Logic is based on + https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + + Args: + query (str): The raw query string. + + Returns: + str: The canonical query string. + """ + # Parse raw query string. + querystring = urllib.parse.parse_qs(query) + querystring_encoded_map = {} + for key in querystring: + quote_key = urllib.parse.quote(key, safe="-_.~") + # URI encode key. + querystring_encoded_map[quote_key] = [] + for item in querystring[key]: + # For each key, URI encode all values for that key. + querystring_encoded_map[quote_key].append( + urllib.parse.quote(item, safe="-_.~") + ) + # Sort values for each key. + querystring_encoded_map[quote_key].sort() + # Sort keys. + sorted_keys = list(querystring_encoded_map.keys()) + sorted_keys.sort() + # Reconstruct the query string. Preserve keys with multiple values. + querystring_encoded_pairs = [] + for key in sorted_keys: + for item in querystring_encoded_map[key]: + querystring_encoded_pairs.append("{}={}".format(key, item)) + return "&".join(querystring_encoded_pairs) + + +def _sign(key, msg): + """Creates the HMAC-SHA256 hash of the provided message using the provided + key. + + Args: + key (str): The HMAC-SHA256 key to use. + msg (str): The message to hash. + + Returns: + str: The computed hash bytes. + """ + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + +def _get_signing_key(key, date_stamp, region_name, service_name): + """Calculates the signing key used to calculate the signature for + AWS Signature Version 4 based on: + https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + + Args: + key (str): The AWS secret access key. + date_stamp (str): The '%Y%m%d' date format. + region_name (str): The AWS region. + service_name (str): The AWS service name, eg. sts. + + Returns: + str: The signing key bytes. + """ + k_date = _sign(("AWS4" + key).encode("utf-8"), date_stamp) + k_region = _sign(k_date, region_name) + k_service = _sign(k_region, service_name) + k_signing = _sign(k_service, "aws4_request") + return k_signing + + +def _generate_authentication_header_map( + host, + canonical_uri, + canonical_querystring, + method, + region, + access_key, + secret_key, + security_token, + request_payload="", + additional_headers={}, +): + """Generates the authentication header map needed for generating the AWS + Signature Version 4 signed request. + + Args: + host (str): The AWS service URL hostname. + canonical_uri (str): The AWS service URL path name. + canonical_querystring (str): The AWS service URL query string. + method (str): The HTTP method used to call this API. + region (str): The AWS region. + access_key (str): The AWS access key ID. + secret_key (str): The AWS secret access key. + security_token (Optional[str]): The AWS security session token. This is + available for temporary sessions. + request_payload (Optional[str]): The optional request payload if + available. + additional_headers (Optional[Mapping[str, str]]): The optional + additional headers needed for the requested AWS API. + + Returns: + Mapping[str, str]: The AWS authentication header dictionary object. + This contains the x-amz-date and authorization header information. + """ + # iam.amazonaws.com host => iam service. + # sts.us-east-2.amazonaws.com host => sts service. + service_name = host.split(".")[0] + + current_time = _helpers.utcnow() + amz_date = current_time.strftime("%Y%m%dT%H%M%SZ") + date_stamp = current_time.strftime("%Y%m%d") + + # Change all additional headers to be lower case. + full_headers = {} + for key in additional_headers: + full_headers[key.lower()] = additional_headers[key] + # Add AWS session token if available. + if security_token is not None: + full_headers[_AWS_SECURITY_TOKEN_HEADER] = security_token + + # Required headers + full_headers["host"] = host + # Do not use generated x-amz-date if the date header is provided. + # Previously the date was not fixed with x-amz- and could be provided + # manually. + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-header-value-trim.req + if "date" not in full_headers: + full_headers[_AWS_DATE_HEADER] = amz_date + + # Header keys need to be sorted alphabetically. + canonical_headers = "" + header_keys = list(full_headers.keys()) + header_keys.sort() + for key in header_keys: + canonical_headers = "{}{}:{}\n".format( + canonical_headers, key, full_headers[key] + ) + signed_headers = ";".join(header_keys) + + payload_hash = hashlib.sha256((request_payload or "").encode("utf-8")).hexdigest() + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + canonical_request = "{}\n{}\n{}\n{}\n{}\n{}".format( + method, + canonical_uri, + canonical_querystring, + canonical_headers, + signed_headers, + payload_hash, + ) + + credential_scope = "{}/{}/{}/{}".format( + date_stamp, region, service_name, _AWS_REQUEST_TYPE + ) + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html + string_to_sign = "{}\n{}\n{}\n{}".format( + _AWS_ALGORITHM, + amz_date, + credential_scope, + hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(), + ) + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html + signing_key = _get_signing_key(secret_key, date_stamp, region, service_name) + signature = hmac.new( + signing_key, string_to_sign.encode("utf-8"), hashlib.sha256 + ).hexdigest() + + # https://docs.aws.amazon.com/general/latest/gr/sigv4-add-signature-to-request.html + authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format( + _AWS_ALGORITHM, access_key, credential_scope, signed_headers, signature + ) + + authentication_header = {"authorization_header": authorization_header} + # Do not use generated x-amz-date if the date header is provided. + if "date" not in full_headers: + authentication_header["amz_date"] = amz_date + return authentication_header + + +class Credentials(external_account.Credentials): + """AWS external account credentials. + This is used to exchange serialized AWS signature v4 signed requests to + AWS STS GetCallerIdentity service for Google access tokens. + """ + + def __init__( + self, + audience, + subject_token_type, + token_url, + credential_source=None, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + ): + """Instantiates an AWS workload external account credentials object. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type. + token_url (str): The STS endpoint URL. + credential_source (Mapping): The credential source dictionary used + to provide instructions on how to retrieve external credential + to be exchanged for Google access tokens. + service_account_impersonation_url (Optional[str]): The optional + service account impersonation getAccessToken URL. + client_id (Optional[str]): The optional client ID. + client_secret (Optional[str]): The optional client secret. + quota_project_id (Optional[str]): The optional quota project ID. + scopes (Optional[Sequence[str]]): Optional scopes to request during + the authorization grant. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + ValueError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + service_account_impersonation_url=service_account_impersonation_url, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + credential_source = credential_source or {} + self._environment_id = credential_source.get("environment_id") or "" + self._region_url = credential_source.get("region_url") + self._security_credentials_url = credential_source.get("url") + self._cred_verification_url = credential_source.get( + "regional_cred_verification_url" + ) + self._region = None + self._request_signer = None + self._target_resource = audience + + # Get the environment ID. Currently, only one version supported (v1). + matches = re.match(r"^(aws)([\d]+)$", self._environment_id) + if matches: + env_id, env_version = matches.groups() + else: + env_id, env_version = (None, None) + + if env_id != "aws" or self._cred_verification_url is None: + raise ValueError("No valid AWS 'credential_source' provided") + elif int(env_version or "") != 1: + raise ValueError( + "aws version '{}' is not supported in the current build.".format( + env_version + ) + ) + + def retrieve_subject_token(self, request): + """Retrieves the subject token using the credential_source object. + The subject token is a serialized `AWS GetCallerIdentity signed request`_. + + The logic is summarized as: + + Retrieve the AWS region from the AWS_REGION environment variable or from + the AWS metadata server availability-zone if not found in the + environment variable. + + Check AWS credentials in environment variables. If not found, retrieve + from the AWS metadata server security-credentials endpoint. + + When retrieving AWS credentials from the metadata server + security-credentials endpoint, the AWS role needs to be determined by + calling the security-credentials endpoint without any argument. Then the + credentials can be retrieved via: security-credentials/role_name + + Generate the signed request to AWS STS GetCallerIdentity action. + + Inject x-goog-cloud-target-resource into header and serialize the + signed request. This will be the subject-token to pass to GCP STS. + + .. _AWS GetCallerIdentity signed request: + https://cloud.google.com/iam/docs/access-resources-aws#exchange-token + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + str: The retrieved subject token. + """ + # Initialize the request signer if not yet initialized after determining + # the current AWS region. + if self._request_signer is None: + self._region = self._get_region(request, self._region_url) + self._request_signer = RequestSigner(self._region) + + # Retrieve the AWS security credentials needed to generate the signed + # request. + aws_security_credentials = self._get_security_credentials(request) + # Generate the signed request to AWS STS GetCallerIdentity API. + # Use the required regional endpoint. Otherwise, the request will fail. + request_options = self._request_signer.get_request_options( + aws_security_credentials, + self._cred_verification_url.replace("{region}", self._region), + "POST", + ) + # The GCP STS endpoint expects the headers to be formatted as: + # [ + # {key: 'x-amz-date', value: '...'}, + # {key: 'Authorization', value: '...'}, + # ... + # ] + # And then serialized as: + # quote(json.dumps({ + # url: '...', + # method: 'POST', + # headers: [{key: 'x-amz-date', value: '...'}, ...] + # })) + request_headers = request_options.get("headers") + # The full, canonical resource name of the workload identity pool + # provider, with or without the HTTPS prefix. + # Including this header as part of the signature is recommended to + # ensure data integrity. + request_headers["x-goog-cloud-target-resource"] = self._target_resource + + # Serialize AWS signed request. + # Keeping inner keys in sorted order makes testing easier for Python + # versions <=3.5 as the stringified JSON string would have a predictable + # key order. + aws_signed_req = {} + aws_signed_req["url"] = request_options.get("url") + aws_signed_req["method"] = request_options.get("method") + aws_signed_req["headers"] = [] + # Reformat header to GCP STS expected format. + for key in sorted(request_headers.keys()): + aws_signed_req["headers"].append( + {"key": key, "value": request_headers[key]} + ) + + return urllib.parse.quote( + json.dumps(aws_signed_req, separators=(",", ":"), sort_keys=True) + ) + + def _get_region(self, request, url): + """Retrieves the current AWS region from either the AWS_REGION + environment variable or from the AWS metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + url (str): The AWS metadata server region URL. + + Returns: + str: The current AWS region. + + Raises: + google.auth.exceptions.RefreshError: If an error occurs while + retrieving the AWS region. + """ + # The AWS metadata server is not available in some AWS environments + # such as AWS lambda. Instead, it is available via environment + # variable. + env_aws_region = os.environ.get(environment_vars.AWS_REGION) + if env_aws_region is not None: + return env_aws_region + + if not self._region_url: + raise exceptions.RefreshError("Unable to determine AWS region") + response = request(url=self._region_url, method="GET") + + # Support both string and bytes type response.data. + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve AWS region", response_body + ) + + # This endpoint will return the region in format: us-east-2b. + # Only the us-east-2 part should be used. + return response_body[:-1] + + def _get_security_credentials(self, request): + """Retrieves the AWS security credentials required for signing AWS + requests from either the AWS security credentials environment variables + or from the AWS metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + Mapping[str, str]: The AWS security credentials dictionary object. + + Raises: + google.auth.exceptions.RefreshError: If an error occurs while + retrieving the AWS security credentials. + """ + + # Check environment variables for permanent credentials first. + # https://docs.aws.amazon.com/general/latest/gr/aws-sec-cred-types.html + env_aws_access_key_id = os.environ.get(environment_vars.AWS_ACCESS_KEY_ID) + env_aws_secret_access_key = os.environ.get( + environment_vars.AWS_SECRET_ACCESS_KEY + ) + # This is normally not available for permanent credentials. + env_aws_session_token = os.environ.get(environment_vars.AWS_SESSION_TOKEN) + if env_aws_access_key_id and env_aws_secret_access_key: + return { + "access_key_id": env_aws_access_key_id, + "secret_access_key": env_aws_secret_access_key, + "security_token": env_aws_session_token, + } + + # Get role name. + role_name = self._get_metadata_role_name(request) + + # Get security credentials. + credentials = self._get_metadata_security_credentials(request, role_name) + + return { + "access_key_id": credentials.get("AccessKeyId"), + "secret_access_key": credentials.get("SecretAccessKey"), + "security_token": credentials.get("Token"), + } + + def _get_metadata_security_credentials(self, request, role_name): + """Retrieves the AWS security credentials required for signing AWS + requests from the AWS metadata server. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + role_name (str): The AWS role name required by the AWS metadata + server security_credentials endpoint in order to return the + credentials. + + Returns: + Mapping[str, str]: The AWS metadata server security credentials + response. + + Raises: + google.auth.exceptions.RefreshError: If an error occurs while + retrieving the AWS security credentials. + """ + headers = {"Content-Type": "application/json"} + response = request( + url="{}/{}".format(self._security_credentials_url, role_name), + method="GET", + headers=headers, + ) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS security credentials", response_body + ) + + credentials_response = json.loads(response_body) + + return credentials_response + + def _get_metadata_role_name(self, request): + """Retrieves the AWS role currently attached to the current AWS + workload by querying the AWS metadata server. This is needed for the + AWS metadata server security credentials endpoint in order to retrieve + the AWS security credentials needed to sign requests to AWS APIs. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + + Returns: + str: The AWS role name. + + Raises: + google.auth.exceptions.RefreshError: If an error occurs while + retrieving the AWS role name. + """ + if self._security_credentials_url is None: + raise exceptions.RefreshError( + "Unable to determine the AWS metadata server security credentials endpoint" + ) + response = request(url=self._security_credentials_url, method="GET") + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != http_client.OK: + raise exceptions.RefreshError( + "Unable to retrieve AWS role name", response_body + ) + + return response_body + + @classmethod + def from_info(cls, info, **kwargs): + """Creates an AWS Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The AWS external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.aws.Credentials: The constructed credentials. + + Raises: + ValueError: For invalid parameters. + """ + return cls( + audience=info.get("audience"), + subject_token_type=info.get("subject_token_type"), + token_url=info.get("token_url"), + service_account_impersonation_url=info.get( + "service_account_impersonation_url" + ), + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + credential_source=info.get("credential_source"), + quota_project_id=info.get("quota_project_id"), + **kwargs + ) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates an AWS Credentials instance from an external account json file. + + Args: + filename (str): The path to the AWS external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.aws.Credentials: The constructed credentials. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return cls.from_info(data, **kwargs) diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index 46a892664..416bab0c0 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -59,3 +59,13 @@ The default value is false. Users have to explicitly set this value to true in order to use client certificate to establish a mutual TLS channel.""" + +# AWS environment variables used with AWS workload identity pools to retrieve +# AWS security credentials and the AWS region needed to create a serialized +# signed requests to the AWS STS GetCalledIdentity API that can be exchanged +# for a Google access tokens via the GCP STS endpoint. +# When not available the AWS metadata server is used to retrieve these values. +AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" +AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" +AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" +AWS_REGION = "AWS_REGION" diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py index da06d8696..b6f686bbb 100644 --- a/google/auth/exceptions.py +++ b/google/auth/exceptions.py @@ -43,3 +43,8 @@ class MutualTLSChannelError(GoogleAuthError): class ClientCertError(GoogleAuthError): """Used to indicate that client certificate is missing or invalid.""" + + +class OAuthError(GoogleAuthError): + """Used to indicate an error occurred during an OAuth related HTTP + request.""" diff --git a/google/auth/external_account.py b/google/auth/external_account.py new file mode 100644 index 000000000..0429ee08f --- /dev/null +++ b/google/auth/external_account.py @@ -0,0 +1,305 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""External Account Credentials. + +This module provides credentials that exchange workload identity pool external +credentials for Google access tokens. This facilitates accessing Google Cloud +Platform resources from on-prem and non-Google Cloud platforms (e.g. AWS, +Microsoft Azure, OIDC identity providers), using native credentials retrieved +from the current environment without the need to copy, save and manage +long-lived service account credentials. + +Specifically, this is intended to use access tokens acquired using the GCP STS +token exchange endpoint following the `OAuth 2.0 Token Exchange`_ spec. + +.. _OAuth 2.0 Token Exchange: https://tools.ietf.org/html/rfc8693 +""" + +import abc +import datetime +import json + +import six + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.auth import impersonated_credentials +from google.oauth2 import sts +from google.oauth2 import utils + +# The token exchange grant_type used for exchanging credentials. +_STS_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" +# The token exchange requested_token_type. This is always an access_token. +_STS_REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +# Cloud resource manager URL used to retrieve project information. +_CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/" + + +@six.add_metaclass(abc.ABCMeta) +class Credentials(credentials.Scoped, credentials.CredentialsWithQuotaProject): + """Base class for all external account credentials. + + This is used to instantiate Credentials for exchanging external account + credentials for Google access token and authorizing requests to Google APIs. + The base class implements the common logic for exchanging external account + credentials for Google access tokens. + """ + + def __init__( + self, + audience, + subject_token_type, + token_url, + credential_source, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + ): + """Instantiates an external account credentials object. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type. + token_url (str): The STS endpoint URL. + credential_source (Mapping): The credential source dictionary. + service_account_impersonation_url (Optional[str]): The optional service account + impersonation generateAccessToken URL. + client_id (Optional[str]): The optional client ID. + client_secret (Optional[str]): The optional client secret. + quota_project_id (Optional[str]): The optional quota project ID. + scopes (Optional[Sequence[str]]): Optional scopes to request during the + authorization grant. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + Raises: + google.auth.exceptions.RefreshError: If the generateAccessToken + endpoint returned an error. + """ + super(Credentials, self).__init__() + self._audience = audience + self._subject_token_type = subject_token_type + self._token_url = token_url + self._credential_source = credential_source + self._service_account_impersonation_url = service_account_impersonation_url + self._client_id = client_id + self._client_secret = client_secret + self._quota_project_id = quota_project_id + self._scopes = scopes + self._default_scopes = default_scopes + + if self._client_id: + self._client_auth = utils.ClientAuthentication( + utils.ClientAuthType.basic, self._client_id, self._client_secret + ) + else: + self._client_auth = None + self._sts_client = sts.Client(self._token_url, self._client_auth) + + if self._service_account_impersonation_url: + self._impersonated_credentials = self._initialize_impersonated_credentials() + else: + self._impersonated_credentials = None + self._project_id = None + + @property + def requires_scopes(self): + """Checks if the credentials requires scopes. + + Returns: + bool: True if there are no scopes set otherwise False. + """ + return not self._scopes and not self._default_scopes + + @property + def project_number(self): + """Optional[str]: The project number corresponding to the workload identity pool.""" + + # STS audience pattern: + # //iam.googleapis.com/projects/$PROJECT_NUMBER/locations/... + components = self._audience.split("/") + try: + project_index = components.index("projects") + if project_index + 1 < len(components): + return components[project_index + 1] or None + except ValueError: + return None + + @_helpers.copy_docstring(credentials.Scoped) + def with_scopes(self, scopes, default_scopes=None): + return self.__class__( + audience=self._audience, + subject_token_type=self._subject_token_type, + token_url=self._token_url, + credential_source=self._credential_source, + service_account_impersonation_url=self._service_account_impersonation_url, + client_id=self._client_id, + client_secret=self._client_secret, + quota_project_id=self._quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + + @abc.abstractmethod + def retrieve_subject_token(self, request): + """Retrieves the subject token using the credential_source object. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + str: The retrieved subject token. + """ + # pylint: disable=missing-raises-doc + # (pylint doesn't recognize that this is abstract) + raise NotImplementedError("retrieve_subject_token must be implemented") + + def get_project_id(self, request): + """Retrieves the project ID corresponding to the workload identity pool. + + When not determinable, None is returned. + + This is introduced to support the current pattern of using the Auth library: + + credentials, project_id = google.auth.default() + + The resource may not have permission (resourcemanager.projects.get) to + call this API or the required scopes may not be selected: + https://cloud.google.com/resource-manager/reference/rest/v1/projects/get#authorization-scopes + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + Returns: + Optional[str]: The project ID corresponding to the workload identity pool + if determinable. + """ + if self._project_id: + # If already retrieved, return the cached project ID value. + return self._project_id + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Scopes are required in order to retrieve a valid access token. + if self.project_number and scopes: + headers = {} + url = _CLOUD_RESOURCE_MANAGER + self.project_number + self.before_request(request, "GET", url, headers) + response = request(url=url, method="GET", headers=headers) + + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + response_data = json.loads(response_body) + + if response.status == 200: + # Cache result as this field is immutable. + self._project_id = response_data.get("projectId") + return self._project_id + + return None + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + scopes = self._scopes if self._scopes is not None else self._default_scopes + if self._impersonated_credentials: + self._impersonated_credentials.refresh(request) + self.token = self._impersonated_credentials.token + self.expiry = self._impersonated_credentials.expiry + else: + now = _helpers.utcnow() + response_data = self._sts_client.exchange_token( + request=request, + grant_type=_STS_GRANT_TYPE, + subject_token=self.retrieve_subject_token(request), + subject_token_type=self._subject_token_type, + audience=self._audience, + scopes=scopes, + requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + ) + self.token = response_data.get("access_token") + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + # Return copy of instance with the provided quota project ID. + return self.__class__( + audience=self._audience, + subject_token_type=self._subject_token_type, + token_url=self._token_url, + credential_source=self._credential_source, + service_account_impersonation_url=self._service_account_impersonation_url, + client_id=self._client_id, + client_secret=self._client_secret, + quota_project_id=quota_project_id, + scopes=self._scopes, + default_scopes=self._default_scopes, + ) + + def _initialize_impersonated_credentials(self): + """Generates an impersonated credentials. + + For more details, see `projects.serviceAccounts.generateAccessToken`_. + + .. _projects.serviceAccounts.generateAccessToken: https://cloud.google.com/iam/docs/reference/credentials/rest/v1/projects.serviceAccounts/generateAccessToken + + Returns: + impersonated_credentials.Credential: The impersonated credentials + object. + + Raises: + google.auth.exceptions.RefreshError: If the generateAccessToken + endpoint returned an error. + """ + # Return copy of instance with no service account impersonation. + source_credentials = self.__class__( + audience=self._audience, + subject_token_type=self._subject_token_type, + token_url=self._token_url, + credential_source=self._credential_source, + service_account_impersonation_url=None, + client_id=self._client_id, + client_secret=self._client_secret, + quota_project_id=self._quota_project_id, + scopes=self._scopes, + default_scopes=self._default_scopes, + ) + + # Determine target_principal. + start_index = self._service_account_impersonation_url.rfind("/") + end_index = self._service_account_impersonation_url.find(":generateAccessToken") + if start_index != -1 and end_index != -1 and start_index < end_index: + start_index = start_index + 1 + target_principal = self._service_account_impersonation_url[ + start_index:end_index + ] + else: + raise exceptions.RefreshError( + "Unable to determine target principal from service account impersonation URL." + ) + + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Initialize and return impersonated credentials. + return impersonated_credentials.Credentials( + source_credentials=source_credentials, + target_principal=target_principal, + target_scopes=scopes, + quota_project_id=self._quota_project_id, + iam_endpoint_override=self._service_account_impersonation_url, + ) diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py new file mode 100644 index 000000000..536219955 --- /dev/null +++ b/google/auth/identity_pool.py @@ -0,0 +1,279 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Identity Pool Credentials. + +This module provides credentials to access Google Cloud resources from on-prem +or non-Google Cloud platforms which support external credentials (e.g. OIDC ID +tokens) retrieved from local file locations or local servers. This includes +Microsoft Azure and OIDC identity providers (e.g. K8s workloads registered with +Hub with Hub workload identity enabled). + +These credentials are recommended over the use of service account credentials +in on-prem/non-Google Cloud platforms as they do not involve the management of +long-live service account private keys. + +Identity Pool Credentials are initialized using external_account +arguments which are typically loaded from an external credentials file or +an external credentials URL. Unlike other Credentials that can be initialized +with a list of explicit arguments, secrets or credentials, external account +clients use the environment and hints/guidelines provided by the +external_account JSON file to retrieve credentials and exchange them for Google +access tokens. +""" + +try: + from collections.abc import Mapping +# Python 2.7 compatibility +except ImportError: # pragma: NO COVER + from collections import Mapping +import io +import json +import os + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import external_account + + +class Credentials(external_account.Credentials): + """External account credentials sourced from files and URLs.""" + + def __init__( + self, + audience, + subject_token_type, + token_url, + credential_source, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + ): + """Instantiates an external account credentials object from a file/URL. + + Args: + audience (str): The STS audience field. + subject_token_type (str): The subject token type. + token_url (str): The STS endpoint URL. + credential_source (Mapping): The credential source dictionary used to + provide instructions on how to retrieve external credential to be + exchanged for Google access tokens. + + Example credential_source for url-sourced credential:: + + { + "url": "http://www.example.com", + "format": { + "type": "json", + "subject_token_field_name": "access_token", + }, + "headers": {"foo": "bar"}, + } + + Example credential_source for file-sourced credential:: + + { + "file": "/path/to/token/file.txt" + } + + service_account_impersonation_url (Optional[str]): The optional service account + impersonation getAccessToken URL. + client_id (Optional[str]): The optional client ID. + client_secret (Optional[str]): The optional client secret. + quota_project_id (Optional[str]): The optional quota project ID. + scopes (Optional[Sequence[str]]): Optional scopes to request during the + authorization grant. + default_scopes (Optional[Sequence[str]]): Default scopes passed by a + Google client library. Use 'scopes' for user-defined scopes. + + Raises: + google.auth.exceptions.RefreshError: If an error is encountered during + access token retrieval logic. + ValueError: For invalid parameters. + + .. note:: Typically one of the helper constructors + :meth:`from_file` or + :meth:`from_info` are used instead of calling the constructor directly. + """ + + super(Credentials, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + service_account_impersonation_url=service_account_impersonation_url, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + if not isinstance(credential_source, Mapping): + self._credential_source_file = None + self._credential_source_url = None + else: + self._credential_source_file = credential_source.get("file") + self._credential_source_url = credential_source.get("url") + self._credential_source_headers = credential_source.get("headers") + credential_source_format = credential_source.get("format", {}) + # Get credential_source format type. When not provided, this + # defaults to text. + self._credential_source_format_type = ( + credential_source_format.get("type") or "text" + ) + # environment_id is only supported in AWS or dedicated future external + # account credentials. + if "environment_id" in credential_source: + raise ValueError( + "Invalid Identity Pool credential_source field 'environment_id'" + ) + if self._credential_source_format_type not in ["text", "json"]: + raise ValueError( + "Invalid credential_source format '{}'".format( + self._credential_source_format_type + ) + ) + # For JSON types, get the required subject_token field name. + if self._credential_source_format_type == "json": + self._credential_source_field_name = credential_source_format.get( + "subject_token_field_name" + ) + if self._credential_source_field_name is None: + raise ValueError( + "Missing subject_token_field_name for JSON credential_source format" + ) + else: + self._credential_source_field_name = None + + if self._credential_source_file and self._credential_source_url: + raise ValueError( + "Ambiguous credential_source. 'file' is mutually exclusive with 'url'." + ) + if not self._credential_source_file and not self._credential_source_url: + raise ValueError( + "Missing credential_source. A 'file' or 'url' must be provided." + ) + + @_helpers.copy_docstring(external_account.Credentials) + def retrieve_subject_token(self, request): + return self._parse_token_data( + self._get_token_data(request), + self._credential_source_format_type, + self._credential_source_field_name, + ) + + def _get_token_data(self, request): + if self._credential_source_file: + return self._get_file_data(self._credential_source_file) + else: + return self._get_url_data( + request, self._credential_source_url, self._credential_source_headers + ) + + def _get_file_data(self, filename): + if not os.path.exists(filename): + raise exceptions.RefreshError("File '{}' was not found.".format(filename)) + + with io.open(filename, "r", encoding="utf-8") as file_obj: + return file_obj.read(), filename + + def _get_url_data(self, request, url, headers): + response = request(url=url, method="GET", headers=headers) + + # support both string and bytes type response.data + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + if response.status != 200: + raise exceptions.RefreshError( + "Unable to retrieve Identity Pool subject token", response_body + ) + + return response_body, url + + def _parse_token_data( + self, token_content, format_type="text", subject_token_field_name=None + ): + content, filename = token_content + if format_type == "text": + token = content + else: + try: + # Parse file content as JSON. + response_data = json.loads(content) + # Get the subject_token. + token = response_data[subject_token_field_name] + except (KeyError, ValueError): + raise exceptions.RefreshError( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + filename, subject_token_field_name + ) + ) + if not token: + raise exceptions.RefreshError( + "Missing subject_token in the credential_source file" + ) + return token + + @classmethod + def from_info(cls, info, **kwargs): + """Creates an Identity Pool Credentials instance from parsed external account info. + + Args: + info (Mapping[str, str]): The Identity Pool external account info in Google + format. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.identity_pool.Credentials: The constructed + credentials. + + Raises: + ValueError: For invalid parameters. + """ + return cls( + audience=info.get("audience"), + subject_token_type=info.get("subject_token_type"), + token_url=info.get("token_url"), + service_account_impersonation_url=info.get( + "service_account_impersonation_url" + ), + client_id=info.get("client_id"), + client_secret=info.get("client_secret"), + credential_source=info.get("credential_source"), + quota_project_id=info.get("quota_project_id"), + **kwargs + ) + + @classmethod + def from_file(cls, filename, **kwargs): + """Creates an IdentityPool Credentials instance from an external account json file. + + Args: + filename (str): The path to the IdentityPool external account json file. + kwargs: Additional arguments to pass to the constructor. + + Returns: + google.auth.identity_pool.Credentials: The constructed + credentials. + """ + with io.open(filename, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + return cls.from_info(data, **kwargs) diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index 4d158373a..b8a6c49a1 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -65,7 +65,9 @@ _DEFAULT_TOKEN_URI = "https://oauth2.googleapis.com/token" -def _make_iam_token_request(request, principal, headers, body): +def _make_iam_token_request( + request, principal, headers, body, iam_endpoint_override=None +): """Makes a request to the Google Cloud IAM service for an access token. Args: request (Request): The Request object to use. @@ -73,6 +75,9 @@ def _make_iam_token_request(request, principal, headers, body): headers (Mapping[str, str]): Map of headers to transmit. body (Mapping[str, str]): JSON Payload body for the iamcredentials API call. + iam_endpoint_override (Optiona[str]): The full IAM endpoint override + with the target_principal embedded. This is useful when supporting + impersonation with regional endpoints. Raises: google.auth.exceptions.TransportError: Raised if there is an underlying @@ -82,7 +87,7 @@ def _make_iam_token_request(request, principal, headers, body): `iamcredentials.googleapis.com` is not enabled or the `Service Account Token Creator` is not assigned """ - iam_endpoint = _IAM_ENDPOINT.format(principal) + iam_endpoint = iam_endpoint_override or _IAM_ENDPOINT.format(principal) body = json.dumps(body).encode("utf-8") @@ -185,6 +190,7 @@ def __init__( delegates=None, lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, quota_project_id=None, + iam_endpoint_override=None, ): """ Args: @@ -209,6 +215,9 @@ def __init__( quota_project_id (Optional[str]): The project ID used for quota and billing. This project may be different from the project used to create the credentials. + iam_endpoint_override (Optiona[str]): The full IAM endpoint override + with the target_principal embedded. This is useful when supporting + impersonation with regional endpoints. """ super(Credentials, self).__init__() @@ -226,6 +235,7 @@ def __init__( self.token = None self.expiry = _helpers.utcnow() self._quota_project_id = quota_project_id + self._iam_endpoint_override = iam_endpoint_override @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): @@ -260,6 +270,7 @@ def _update_token(self, request): principal=self._target_principal, headers=headers, body=body, + iam_endpoint_override=self._iam_endpoint_override, ) def sign_bytes(self, message): @@ -302,6 +313,7 @@ def with_quota_project(self, quota_project_id): delegates=self._delegates, lifetime=self._lifetime, quota_project_id=quota_project_id, + iam_endpoint_override=self._iam_endpoint_override, ) diff --git a/google/oauth2/sts.py b/google/oauth2/sts.py new file mode 100644 index 000000000..ae3c0146b --- /dev/null +++ b/google/oauth2/sts.py @@ -0,0 +1,155 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Token Exchange Spec. + +This module defines a token exchange utility based on the `OAuth 2.0 Token +Exchange`_ spec. This will be mainly used to exchange external credentials +for GCP access tokens in workload identity pools to access Google APIs. + +The implementation will support various types of client authentication as +allowed in the spec. + +A deviation on the spec will be for additional Google specific options that +cannot be easily mapped to parameters defined in the RFC. + +The returned dictionary response will be based on the `rfc8693 section 2.2.1`_ +spec JSON response. + +.. _OAuth 2.0 Token Exchange: https://tools.ietf.org/html/rfc8693 +.. _rfc8693 section 2.2.1: https://tools.ietf.org/html/rfc8693#section-2.2.1 +""" + +import json + +from six.moves import http_client +from six.moves import urllib + +from google.oauth2 import utils + + +_URLENCODED_HEADERS = {"Content-Type": "application/x-www-form-urlencoded"} + + +class Client(utils.OAuthClientAuthHandler): + """Implements the OAuth 2.0 token exchange spec based on + https://tools.ietf.org/html/rfc8693. + """ + + def __init__(self, token_exchange_endpoint, client_authentication=None): + """Initializes an STS client instance. + + Args: + token_exchange_endpoint (str): The token exchange endpoint. + client_authentication (Optional(google.oauth2.oauth2_utils.ClientAuthentication)): + The optional OAuth client authentication credentials if available. + """ + super(Client, self).__init__(client_authentication) + self._token_exchange_endpoint = token_exchange_endpoint + + def exchange_token( + self, + request, + grant_type, + subject_token, + subject_token_type, + resource=None, + audience=None, + scopes=None, + requested_token_type=None, + actor_token=None, + actor_token_type=None, + additional_options=None, + additional_headers=None, + ): + """Exchanges the provided token for another type of token based on the + rfc8693 spec. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + grant_type (str): The OAuth 2.0 token exchange grant type. + subject_token (str): The OAuth 2.0 token exchange subject token. + subject_token_type (str): The OAuth 2.0 token exchange subject token type. + resource (Optional[str]): The optional OAuth 2.0 token exchange resource field. + audience (Optional[str]): The optional OAuth 2.0 token exchange audience field. + scopes (Optional[Sequence[str]]): The optional list of scopes to use. + requested_token_type (Optional[str]): The optional OAuth 2.0 token exchange requested + token type. + actor_token (Optional[str]): The optional OAuth 2.0 token exchange actor token. + actor_token_type (Optional[str]): The optional OAuth 2.0 token exchange actor token type. + additional_options (Optional[Mapping[str, str]]): The optional additional + non-standard Google specific options. + additional_headers (Optional[Mapping[str, str]]): The optional additional + headers to pass to the token exchange endpoint. + + Returns: + Mapping[str, str]: The token exchange JSON-decoded response data containing + the requested token and its expiration time. + + Raises: + google.auth.exceptions.OAuthError: If the token endpoint returned + an error. + """ + # Initialize request headers. + headers = _URLENCODED_HEADERS.copy() + # Inject additional headers. + if additional_headers: + for k, v in dict(additional_headers).items(): + headers[k] = v + # Initialize request body. + request_body = { + "grant_type": grant_type, + "resource": resource, + "audience": audience, + "scope": " ".join(scopes or []), + "requested_token_type": requested_token_type, + "subject_token": subject_token, + "subject_token_type": subject_token_type, + "actor_token": actor_token, + "actor_token_type": actor_token_type, + "options": None, + } + # Add additional non-standard options. + if additional_options: + request_body["options"] = urllib.parse.quote(json.dumps(additional_options)) + # Remove empty fields in request body. + for k, v in dict(request_body).items(): + if v is None or v == "": + del request_body[k] + # Apply OAuth client authentication. + self.apply_client_authentication_options(headers, request_body) + + # Execute request. + response = request( + url=self._token_exchange_endpoint, + method="POST", + headers=headers, + body=urllib.parse.urlencode(request_body).encode("utf-8"), + ) + + response_body = ( + response.data.decode("utf-8") + if hasattr(response.data, "decode") + else response.data + ) + + # If non-200 response received, translate to OAuthError exception. + if response.status != http_client.OK: + utils.handle_error_response(response_body) + + response_data = json.loads(response_body) + + # Return successful response. + return response_data diff --git a/google/oauth2/utils.py b/google/oauth2/utils.py new file mode 100644 index 000000000..efda7968d --- /dev/null +++ b/google/oauth2/utils.py @@ -0,0 +1,171 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Utilities. + +This module provides implementations for various OAuth 2.0 utilities. +This includes `OAuth error handling`_ and +`Client authentication for OAuth flows`_. + +OAuth error handling +-------------------- +This will define interfaces for handling OAuth related error responses as +stated in `RFC 6749 section 5.2`_. +This will include a common function to convert these HTTP error responses to a +:class:`google.auth.exceptions.OAuthError` exception. + + +Client authentication for OAuth flows +------------------------------------- +We introduce an interface for defining client authentication credentials based +on `RFC 6749 section 2.3.1`_. This will expose the following +capabilities: + + * Ability to support basic authentication via request header. + * Ability to support bearer token authentication via request header. + * Ability to support client ID / secret authentication via request body. + +.. _RFC 6749 section 2.3.1: https://tools.ietf.org/html/rfc6749#section-2.3.1 +.. _RFC 6749 section 5.2: https://tools.ietf.org/html/rfc6749#section-5.2 +""" + +import abc +import base64 +import enum +import json + +import six + +from google.auth import exceptions + + +# OAuth client authentication based on +# https://tools.ietf.org/html/rfc6749#section-2.3. +class ClientAuthType(enum.Enum): + basic = 1 + request_body = 2 + + +class ClientAuthentication(object): + """Defines the client authentication credentials for basic and request-body + types based on https://tools.ietf.org/html/rfc6749#section-2.3.1. + """ + + def __init__(self, client_auth_type, client_id, client_secret=None): + """Instantiates a client authentication object containing the client ID + and secret credentials for basic and response-body auth. + + Args: + client_auth_type (google.oauth2.oauth_utils.ClientAuthType): The + client authentication type. + client_id (str): The client ID. + client_secret (Optional[str]): The client secret. + """ + self.client_auth_type = client_auth_type + self.client_id = client_id + self.client_secret = client_secret + + +@six.add_metaclass(abc.ABCMeta) +class OAuthClientAuthHandler(object): + """Abstract class for handling client authentication in OAuth-based + operations. + """ + + def __init__(self, client_authentication=None): + """Instantiates an OAuth client authentication handler. + + Args: + client_authentication (Optional[google.oauth2.utils.ClientAuthentication]): + The OAuth client authentication credentials if available. + """ + super(OAuthClientAuthHandler, self).__init__() + self._client_authentication = client_authentication + + def apply_client_authentication_options( + self, headers, request_body=None, bearer_token=None + ): + """Applies client authentication on the OAuth request's headers or POST + body. + + Args: + headers (Mapping[str, str]): The HTTP request header. + request_body (Optional[Mapping[str, str]): The HTTP request body + dictionary. For requests that do not support request body, this + is None and will be ignored. + bearer_token (Optional[str]): The optional bearer token. + """ + # Inject authenticated header. + self._inject_authenticated_headers(headers, bearer_token) + # Inject authenticated request body. + if bearer_token is None: + self._inject_authenticated_request_body(request_body) + + def _inject_authenticated_headers(self, headers, bearer_token=None): + if bearer_token is not None: + headers["Authorization"] = "Bearer %s" % bearer_token + elif ( + self._client_authentication is not None + and self._client_authentication.client_auth_type is ClientAuthType.basic + ): + username = self._client_authentication.client_id + password = self._client_authentication.client_secret or "" + + credentials = base64.b64encode( + ("%s:%s" % (username, password)).encode() + ).decode() + headers["Authorization"] = "Basic %s" % credentials + + def _inject_authenticated_request_body(self, request_body): + if ( + self._client_authentication is not None + and self._client_authentication.client_auth_type + is ClientAuthType.request_body + ): + if request_body is None: + raise exceptions.OAuthError( + "HTTP request does not support request-body" + ) + else: + request_body["client_id"] = self._client_authentication.client_id + request_body["client_secret"] = ( + self._client_authentication.client_secret or "" + ) + + +def handle_error_response(response_body): + """Translates an error response from an OAuth operation into an + OAuthError exception. + + Args: + response_body (str): The decoded response data. + + Raises: + google.auth.exceptions.OAuthError + """ + try: + error_components = [] + error_data = json.loads(response_body) + + error_components.append("Error code {}".format(error_data["error"])) + if "error_description" in error_data: + error_components.append(": {}".format(error_data["error_description"])) + if "error_uri" in error_data: + error_components.append(" - {}".format(error_data["error_uri"])) + error_details = "".join(error_components) + # If no details could be extracted, use the response data. + except (KeyError, ValueError): + error_details = response_body + + raise exceptions.OAuthError(error_details, response_body) diff --git a/noxfile.py b/noxfile.py index adce2527c..fa88d24b2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -64,9 +64,7 @@ def lint(session): @nox.session(python="3.6") def blacken(session): """Run black. - Format code to uniform standard. - This currently uses Python 3.6 due to the automated Kokoro run of synthtool. That run uses an image that doesn't have 3.6 installed. Before updating this check the state of the `gcp_ubuntu_config` we use for that Kokoro run. diff --git a/system_tests/system_tests_sync/secrets.tar.enc b/system_tests/system_tests_sync/secrets.tar.enc new file mode 100644 index 000000000..29e06923f Binary files /dev/null and b/system_tests/system_tests_sync/secrets.tar.enc differ diff --git a/tests/data/external_subject_token.json b/tests/data/external_subject_token.json new file mode 100644 index 000000000..a47ec3412 --- /dev/null +++ b/tests/data/external_subject_token.json @@ -0,0 +1,3 @@ +{ + "access_token": "HEADER.SIMULATED_JWT_PAYLOAD.SIGNATURE" +} \ No newline at end of file diff --git a/tests/data/external_subject_token.txt b/tests/data/external_subject_token.txt new file mode 100644 index 000000000..c668d8f71 --- /dev/null +++ b/tests/data/external_subject_token.txt @@ -0,0 +1 @@ +HEADER.SIMULATED_JWT_PAYLOAD.SIGNATURE \ No newline at end of file diff --git a/tests/oauth2/test_sts.py b/tests/oauth2/test_sts.py new file mode 100644 index 000000000..8792bd6bc --- /dev/null +++ b/tests/oauth2/test_sts.py @@ -0,0 +1,395 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import mock +import pytest +from six.moves import http_client +from six.moves import urllib + +from google.auth import exceptions +from google.auth import transport +from google.oauth2 import sts +from google.oauth2 import utils + +CLIENT_ID = "username" +CLIENT_SECRET = "password" +# Base64 encoding of "username:password" +BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" + + +class TestStsClient(object): + GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" + RESOURCE = "https://api.example.com/" + AUDIENCE = "urn:example:cooperation-context" + SCOPES = ["scope1", "scope2"] + REQUESTED_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" + SUBJECT_TOKEN = "HEADER.SUBJECT_TOKEN_PAYLOAD.SIGNATURE" + SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" + ACTOR_TOKEN = "HEADER.ACTOR_TOKEN_PAYLOAD.SIGNATURE" + ACTOR_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" + TOKEN_EXCHANGE_ENDPOINT = "https://example.com/token.oauth2" + ADDON_HEADERS = {"x-client-version": "0.1.2"} + ADDON_OPTIONS = {"additional": {"non-standard": ["options"], "other": "some-value"}} + SUCCESS_RESPONSE = { + "access_token": "ACCESS_TOKEN", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "scope1 scope2", + } + ERROR_RESPONSE = { + "error": "invalid_request", + "error_description": "Invalid subject token", + "error_uri": "https://tools.ietf.org/html/rfc6749", + } + CLIENT_AUTH_BASIC = utils.ClientAuthentication( + utils.ClientAuthType.basic, CLIENT_ID, CLIENT_SECRET + ) + CLIENT_AUTH_REQUEST_BODY = utils.ClientAuthentication( + utils.ClientAuthType.request_body, CLIENT_ID, CLIENT_SECRET + ) + + @classmethod + def make_client(cls, client_auth=None): + return sts.Client(cls.TOKEN_EXCHANGE_ENDPOINT, client_auth) + + @classmethod + def make_mock_request(cls, data, status=http_client.OK): + response = mock.create_autospec(transport.Response, instance=True) + response.status = status + response.data = json.dumps(data).encode("utf-8") + + request = mock.create_autospec(transport.Request) + request.return_value = response + + return request + + @classmethod + def assert_request_kwargs(cls, request_kwargs, headers, request_data): + """Asserts the request was called with the expected parameters. + """ + assert request_kwargs["url"] == cls.TOKEN_EXCHANGE_ENDPOINT + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_tuples = urllib.parse.parse_qsl(request_kwargs["body"]) + for (k, v) in body_tuples: + assert v.decode("utf-8") == request_data[k.decode("utf-8")] + assert len(body_tuples) == len(request_data.keys()) + + def test_exchange_token_full_success_without_auth(self): + """Test token exchange success without client authentication using full + parameters. + """ + client = self.make_client() + headers = self.ADDON_HEADERS.copy() + headers["Content-Type"] = "application/x-www-form-urlencoded" + request_data = { + "grant_type": self.GRANT_TYPE, + "resource": self.RESOURCE, + "audience": self.AUDIENCE, + "scope": " ".join(self.SCOPES), + "requested_token_type": self.REQUESTED_TOKEN_TYPE, + "subject_token": self.SUBJECT_TOKEN, + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "actor_token": self.ACTOR_TOKEN, + "actor_token_type": self.ACTOR_TOKEN_TYPE, + "options": urllib.parse.quote(json.dumps(self.ADDON_OPTIONS)), + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.exchange_token( + request, + self.GRANT_TYPE, + self.SUBJECT_TOKEN, + self.SUBJECT_TOKEN_TYPE, + self.RESOURCE, + self.AUDIENCE, + self.SCOPES, + self.REQUESTED_TOKEN_TYPE, + self.ACTOR_TOKEN, + self.ACTOR_TOKEN_TYPE, + self.ADDON_OPTIONS, + self.ADDON_HEADERS, + ) + + self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_exchange_token_partial_success_without_auth(self): + """Test token exchange success without client authentication using + partial (required only) parameters. + """ + client = self.make_client() + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": self.GRANT_TYPE, + "audience": self.AUDIENCE, + "requested_token_type": self.REQUESTED_TOKEN_TYPE, + "subject_token": self.SUBJECT_TOKEN, + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.exchange_token( + request, + grant_type=self.GRANT_TYPE, + subject_token=self.SUBJECT_TOKEN, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + audience=self.AUDIENCE, + requested_token_type=self.REQUESTED_TOKEN_TYPE, + ) + + self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_exchange_token_non200_without_auth(self): + """Test token exchange without client auth responding with non-200 status. + """ + client = self.make_client() + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + client.exchange_token( + request, + self.GRANT_TYPE, + self.SUBJECT_TOKEN, + self.SUBJECT_TOKEN_TYPE, + self.RESOURCE, + self.AUDIENCE, + self.SCOPES, + self.REQUESTED_TOKEN_TYPE, + self.ACTOR_TOKEN, + self.ACTOR_TOKEN_TYPE, + self.ADDON_OPTIONS, + self.ADDON_HEADERS, + ) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) + + def test_exchange_token_full_success_with_basic_auth(self): + """Test token exchange success with basic client authentication using full + parameters. + """ + client = self.make_client(self.CLIENT_AUTH_BASIC) + headers = self.ADDON_HEADERS.copy() + headers["Content-Type"] = "application/x-www-form-urlencoded" + headers["Authorization"] = "Basic {}".format(BASIC_AUTH_ENCODING) + request_data = { + "grant_type": self.GRANT_TYPE, + "resource": self.RESOURCE, + "audience": self.AUDIENCE, + "scope": " ".join(self.SCOPES), + "requested_token_type": self.REQUESTED_TOKEN_TYPE, + "subject_token": self.SUBJECT_TOKEN, + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "actor_token": self.ACTOR_TOKEN, + "actor_token_type": self.ACTOR_TOKEN_TYPE, + "options": urllib.parse.quote(json.dumps(self.ADDON_OPTIONS)), + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.exchange_token( + request, + self.GRANT_TYPE, + self.SUBJECT_TOKEN, + self.SUBJECT_TOKEN_TYPE, + self.RESOURCE, + self.AUDIENCE, + self.SCOPES, + self.REQUESTED_TOKEN_TYPE, + self.ACTOR_TOKEN, + self.ACTOR_TOKEN_TYPE, + self.ADDON_OPTIONS, + self.ADDON_HEADERS, + ) + + self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_exchange_token_partial_success_with_basic_auth(self): + """Test token exchange success with basic client authentication using + partial (required only) parameters. + """ + client = self.make_client(self.CLIENT_AUTH_BASIC) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + request_data = { + "grant_type": self.GRANT_TYPE, + "audience": self.AUDIENCE, + "requested_token_type": self.REQUESTED_TOKEN_TYPE, + "subject_token": self.SUBJECT_TOKEN, + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.exchange_token( + request, + grant_type=self.GRANT_TYPE, + subject_token=self.SUBJECT_TOKEN, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + audience=self.AUDIENCE, + requested_token_type=self.REQUESTED_TOKEN_TYPE, + ) + + self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_exchange_token_non200_with_basic_auth(self): + """Test token exchange with basic client auth responding with non-200 + status. + """ + client = self.make_client(self.CLIENT_AUTH_BASIC) + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + client.exchange_token( + request, + self.GRANT_TYPE, + self.SUBJECT_TOKEN, + self.SUBJECT_TOKEN_TYPE, + self.RESOURCE, + self.AUDIENCE, + self.SCOPES, + self.REQUESTED_TOKEN_TYPE, + self.ACTOR_TOKEN, + self.ACTOR_TOKEN_TYPE, + self.ADDON_OPTIONS, + self.ADDON_HEADERS, + ) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) + + def test_exchange_token_full_success_with_reqbody_auth(self): + """Test token exchange success with request body client authenticaiton + using full parameters. + """ + client = self.make_client(self.CLIENT_AUTH_REQUEST_BODY) + headers = self.ADDON_HEADERS.copy() + headers["Content-Type"] = "application/x-www-form-urlencoded" + request_data = { + "grant_type": self.GRANT_TYPE, + "resource": self.RESOURCE, + "audience": self.AUDIENCE, + "scope": " ".join(self.SCOPES), + "requested_token_type": self.REQUESTED_TOKEN_TYPE, + "subject_token": self.SUBJECT_TOKEN, + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "actor_token": self.ACTOR_TOKEN, + "actor_token_type": self.ACTOR_TOKEN_TYPE, + "options": urllib.parse.quote(json.dumps(self.ADDON_OPTIONS)), + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.exchange_token( + request, + self.GRANT_TYPE, + self.SUBJECT_TOKEN, + self.SUBJECT_TOKEN_TYPE, + self.RESOURCE, + self.AUDIENCE, + self.SCOPES, + self.REQUESTED_TOKEN_TYPE, + self.ACTOR_TOKEN, + self.ACTOR_TOKEN_TYPE, + self.ADDON_OPTIONS, + self.ADDON_HEADERS, + ) + + self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_exchange_token_partial_success_with_reqbody_auth(self): + """Test token exchange success with request body client authentication + using partial (required only) parameters. + """ + client = self.make_client(self.CLIENT_AUTH_REQUEST_BODY) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": self.GRANT_TYPE, + "audience": self.AUDIENCE, + "requested_token_type": self.REQUESTED_TOKEN_TYPE, + "subject_token": self.SUBJECT_TOKEN, + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + + response = client.exchange_token( + request, + grant_type=self.GRANT_TYPE, + subject_token=self.SUBJECT_TOKEN, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + audience=self.AUDIENCE, + requested_token_type=self.REQUESTED_TOKEN_TYPE, + ) + + self.assert_request_kwargs(request.call_args.kwargs, headers, request_data) + assert response == self.SUCCESS_RESPONSE + + def test_exchange_token_non200_with_reqbody_auth(self): + """Test token exchange with POST request body client auth responding + with non-200 status. + """ + client = self.make_client(self.CLIENT_AUTH_REQUEST_BODY) + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + client.exchange_token( + request, + self.GRANT_TYPE, + self.SUBJECT_TOKEN, + self.SUBJECT_TOKEN_TYPE, + self.RESOURCE, + self.AUDIENCE, + self.SCOPES, + self.REQUESTED_TOKEN_TYPE, + self.ACTOR_TOKEN, + self.ACTOR_TOKEN_TYPE, + self.ADDON_OPTIONS, + self.ADDON_HEADERS, + ) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) diff --git a/tests/oauth2/test_utils.py b/tests/oauth2/test_utils.py new file mode 100644 index 000000000..6de9ff533 --- /dev/null +++ b/tests/oauth2/test_utils.py @@ -0,0 +1,264 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import pytest + +from google.auth import exceptions +from google.oauth2 import utils + + +CLIENT_ID = "username" +CLIENT_SECRET = "password" +# Base64 encoding of "username:password" +BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" +# Base64 encoding of "username:" +BASIC_AUTH_ENCODING_SECRETLESS = "dXNlcm5hbWU6" + + +class AuthHandler(utils.OAuthClientAuthHandler): + def __init__(self, client_auth=None): + super(AuthHandler, self).__init__(client_auth) + + def apply_client_authentication_options( + self, headers, request_body=None, bearer_token=None + ): + return super(AuthHandler, self).apply_client_authentication_options( + headers, request_body, bearer_token + ) + + +class TestClientAuthentication(object): + @classmethod + def make_client_auth(cls, client_secret=None): + return utils.ClientAuthentication( + utils.ClientAuthType.basic, CLIENT_ID, client_secret + ) + + def test_initialization_with_client_secret(self): + client_auth = self.make_client_auth(CLIENT_SECRET) + + assert client_auth.client_auth_type == utils.ClientAuthType.basic + assert client_auth.client_id == CLIENT_ID + assert client_auth.client_secret == CLIENT_SECRET + + def test_initialization_no_client_secret(self): + client_auth = self.make_client_auth() + + assert client_auth.client_auth_type == utils.ClientAuthType.basic + assert client_auth.client_id == CLIENT_ID + assert client_auth.client_secret is None + + +class TestOAuthClientAuthHandler(object): + CLIENT_AUTH_BASIC = utils.ClientAuthentication( + utils.ClientAuthType.basic, CLIENT_ID, CLIENT_SECRET + ) + CLIENT_AUTH_BASIC_SECRETLESS = utils.ClientAuthentication( + utils.ClientAuthType.basic, CLIENT_ID + ) + CLIENT_AUTH_REQUEST_BODY = utils.ClientAuthentication( + utils.ClientAuthType.request_body, CLIENT_ID, CLIENT_SECRET + ) + CLIENT_AUTH_REQUEST_BODY_SECRETLESS = utils.ClientAuthentication( + utils.ClientAuthType.request_body, CLIENT_ID + ) + + @classmethod + def make_oauth_client_auth_handler(cls, client_auth=None): + return AuthHandler(client_auth) + + def test_apply_client_authentication_options_none(self): + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler() + + auth_handler.apply_client_authentication_options(headers, request_body) + + assert headers == {"Content-Type": "application/json"} + assert request_body == {"foo": "bar"} + + def test_apply_client_authentication_options_basic(self): + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler(self.CLIENT_AUTH_BASIC) + + auth_handler.apply_client_authentication_options(headers, request_body) + + assert headers == { + "Content-Type": "application/json", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + assert request_body == {"foo": "bar"} + + def test_apply_client_authentication_options_basic_nosecret(self): + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler( + self.CLIENT_AUTH_BASIC_SECRETLESS + ) + + auth_handler.apply_client_authentication_options(headers, request_body) + + assert headers == { + "Content-Type": "application/json", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING_SECRETLESS), + } + assert request_body == {"foo": "bar"} + + def test_apply_client_authentication_options_request_body(self): + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler( + self.CLIENT_AUTH_REQUEST_BODY + ) + + auth_handler.apply_client_authentication_options(headers, request_body) + + assert headers == {"Content-Type": "application/json"} + assert request_body == { + "foo": "bar", + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + } + + def test_apply_client_authentication_options_request_body_nosecret(self): + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler( + self.CLIENT_AUTH_REQUEST_BODY_SECRETLESS + ) + + auth_handler.apply_client_authentication_options(headers, request_body) + + assert headers == {"Content-Type": "application/json"} + assert request_body == { + "foo": "bar", + "client_id": CLIENT_ID, + "client_secret": "", + } + + def test_apply_client_authentication_options_request_body_no_body(self): + headers = {"Content-Type": "application/json"} + auth_handler = self.make_oauth_client_auth_handler( + self.CLIENT_AUTH_REQUEST_BODY + ) + + with pytest.raises(exceptions.OAuthError) as excinfo: + auth_handler.apply_client_authentication_options(headers) + + assert excinfo.match(r"HTTP request does not support request-body") + + def test_apply_client_authentication_options_bearer_token(self): + bearer_token = "ACCESS_TOKEN" + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler() + + auth_handler.apply_client_authentication_options( + headers, request_body, bearer_token + ) + + assert headers == { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(bearer_token), + } + assert request_body == {"foo": "bar"} + + def test_apply_client_authentication_options_bearer_and_basic(self): + bearer_token = "ACCESS_TOKEN" + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler(self.CLIENT_AUTH_BASIC) + + auth_handler.apply_client_authentication_options( + headers, request_body, bearer_token + ) + + # Bearer token should have higher priority. + assert headers == { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(bearer_token), + } + assert request_body == {"foo": "bar"} + + def test_apply_client_authentication_options_bearer_and_request_body(self): + bearer_token = "ACCESS_TOKEN" + headers = {"Content-Type": "application/json"} + request_body = {"foo": "bar"} + auth_handler = self.make_oauth_client_auth_handler( + self.CLIENT_AUTH_REQUEST_BODY + ) + + auth_handler.apply_client_authentication_options( + headers, request_body, bearer_token + ) + + # Bearer token should have higher priority. + assert headers == { + "Content-Type": "application/json", + "Authorization": "Bearer {}".format(bearer_token), + } + assert request_body == {"foo": "bar"} + + +def test__handle_error_response_code_only(): + error_resp = {"error": "unsupported_grant_type"} + response_data = json.dumps(error_resp) + + with pytest.raises(exceptions.OAuthError) as excinfo: + utils.handle_error_response(response_data) + + assert excinfo.match(r"Error code unsupported_grant_type") + + +def test__handle_error_response_code_description(): + error_resp = { + "error": "unsupported_grant_type", + "error_description": "The provided grant_type is unsupported", + } + response_data = json.dumps(error_resp) + + with pytest.raises(exceptions.OAuthError) as excinfo: + utils.handle_error_response(response_data) + + assert excinfo.match( + r"Error code unsupported_grant_type: The provided grant_type is unsupported" + ) + + +def test__handle_error_response_code_description_uri(): + error_resp = { + "error": "unsupported_grant_type", + "error_description": "The provided grant_type is unsupported", + "error_uri": "https://tools.ietf.org/html/rfc6749", + } + response_data = json.dumps(error_resp) + + with pytest.raises(exceptions.OAuthError) as excinfo: + utils.handle_error_response(response_data) + + assert excinfo.match( + r"Error code unsupported_grant_type: The provided grant_type is unsupported - https://tools.ietf.org/html/rfc6749" + ) + + +def test__handle_error_response_non_json(): + response_data = "Oops, something wrong happened" + + with pytest.raises(exceptions.OAuthError) as excinfo: + utils.handle_error_response(response_data) + + assert excinfo.match(r"Oops, something wrong happened") diff --git a/tests/test__default.py b/tests/test__default.py index 74511f9e5..ef6cb78d2 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -20,10 +20,13 @@ from google.auth import _default from google.auth import app_engine +from google.auth import aws from google.auth import compute_engine from google.auth import credentials from google.auth import environment_vars from google.auth import exceptions +from google.auth import external_account +from google.auth import identity_pool from google.oauth2 import service_account import google.oauth2.credentials @@ -49,6 +52,34 @@ with open(SERVICE_ACCOUNT_FILE) as fh: SERVICE_ACCOUNT_FILE_DATA = json.load(fh) +SUBJECT_TOKEN_TEXT_FILE = os.path.join(DATA_DIR, "external_subject_token.txt") +TOKEN_URL = "https://sts.googleapis.com/v1/token" +AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID" +REGION_URL = "http://169.254.169.254/latest/meta-data/placement/availability-zone" +SECURITY_CREDS_URL = "http://169.254.169.254/latest/meta-data/iam/security-credentials" +CRED_VERIFICATION_URL = ( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" +) +IDENTITY_POOL_DATA = { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": "urn:ietf:params:oauth:token-type:jwt", + "token_url": TOKEN_URL, + "credential_source": {"file": SUBJECT_TOKEN_TEXT_FILE}, +} +AWS_DATA = { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request", + "token_url": TOKEN_URL, + "credential_source": { + "environment_id": "aws1", + "region_url": REGION_URL, + "url": SECURITY_CREDS_URL, + "regional_cred_verification_url": CRED_VERIFICATION_URL, + }, +} + MOCK_CREDENTIALS = mock.Mock(spec=credentials.CredentialsWithQuotaProject) MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS @@ -57,6 +88,12 @@ return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), autospec=True, ) +EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH = mock.patch.object( + external_account.Credentials, + "get_project_id", + return_value=mock.sentinel.project_id, + autospec=True, +) def test_load_credentials_from_missing_file(): @@ -185,6 +222,92 @@ def test_load_credentials_from_file_service_account_bad_format(tmpdir): assert excinfo.match(r"missing fields") +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_load_credentials_from_file_external_account_identity_pool( + get_project_id, tmpdir +): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + credentials, project_id = _default.load_credentials_from_file(str(config_file)) + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + assert get_project_id.called + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_load_credentials_from_file_external_account_aws(get_project_id, tmpdir): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(AWS_DATA)) + credentials, project_id = _default.load_credentials_from_file(str(config_file)) + + assert isinstance(credentials, aws.Credentials) + assert project_id is mock.sentinel.project_id + assert get_project_id.called + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_load_credentials_from_file_external_account_with_user_and_default_scopes( + get_project_id, tmpdir +): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + credentials, project_id = _default.load_credentials_from_file( + str(config_file), + scopes=["https://www.google.com/calendar/feeds"], + default_scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + assert credentials.scopes == ["https://www.google.com/calendar/feeds"] + assert credentials.default_scopes == [ + "https://www.googleapis.com/auth/cloud-platform" + ] + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_load_credentials_from_file_external_account_with_quota_project( + get_project_id, tmpdir +): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + credentials, project_id = _default.load_credentials_from_file( + str(config_file), quota_project_id="project-foo" + ) + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + assert credentials.quota_project_id == "project-foo" + + +def test_load_credentials_from_file_external_account_bad_format(tmpdir): + filename = tmpdir.join("external_account_bad.json") + filename.write(json.dumps({"type": "external_account"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(filename)) + + assert excinfo.match( + "Failed to load external account credentials from {}".format(str(filename)) + ) + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_load_credentials_from_file_external_account_explicit_request( + get_project_id, tmpdir +): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + credentials, project_id = _default.load_credentials_from_file( + str(config_file), request=mock.sentinel.request + ) + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + get_project_id.assert_called_with(credentials, request=mock.sentinel.request) + + @mock.patch.dict(os.environ, {}, clear=True) def test__get_explicit_environ_credentials_no_env(): assert _default._get_explicit_environ_credentials() == (None, None) @@ -198,7 +321,34 @@ def test__get_explicit_environ_credentials(load, monkeypatch): assert credentials is MOCK_CREDENTIALS assert project_id is mock.sentinel.project_id - load.assert_called_with("filename") + load.assert_called_with( + "filename", + scopes=None, + default_scopes=None, + quota_project_id=None, + request=None, + ) + + +@LOAD_FILE_PATCH +def test__get_explicit_environ_credentials_with_scopes_and_request(load, monkeypatch): + scopes = ["one", "two"] + monkeypatch.setenv(environment_vars.CREDENTIALS, "filename") + + credentials, project_id = _default._get_explicit_environ_credentials( + request=mock.sentinel.request, scopes=scopes + ) + + assert credentials is MOCK_CREDENTIALS + assert project_id is mock.sentinel.project_id + # Request and scopes should be propagated. + load.assert_called_with( + "filename", + scopes=scopes, + default_scopes=None, + quota_project_id=None, + request=mock.sentinel.request, + ) @LOAD_FILE_PATCH @@ -503,3 +653,70 @@ def test_default_no_app_engine_compute_engine_module(unused_get): sys.modules["google.auth.compute_engine"] = None sys.modules["google.auth.app_engine"] = None assert _default.default() == (MOCK_CREDENTIALS, mock.sentinel.project_id) + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_default_environ_external_credentials(get_project_id, monkeypatch, tmpdir): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + monkeypatch.setenv(environment_vars.CREDENTIALS, str(config_file)) + + credentials, project_id = _default.default() + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_default_environ_external_credentials_with_user_and_default_scopes_and_quota_project_id( + get_project_id, monkeypatch, tmpdir +): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + monkeypatch.setenv(environment_vars.CREDENTIALS, str(config_file)) + + credentials, project_id = _default.default( + scopes=["https://www.google.com/calendar/feeds"], + default_scopes=["https://www.googleapis.com/auth/cloud-platform"], + quota_project_id="project-foo", + ) + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + assert credentials.quota_project_id == "project-foo" + assert credentials.scopes == ["https://www.google.com/calendar/feeds"] + assert credentials.default_scopes == [ + "https://www.googleapis.com/auth/cloud-platform" + ] + + +@EXTERNAL_ACCOUNT_GET_PROJECT_ID_PATCH +def test_default_environ_external_credentials_explicit_request( + get_project_id, monkeypatch, tmpdir +): + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(IDENTITY_POOL_DATA)) + monkeypatch.setenv(environment_vars.CREDENTIALS, str(config_file)) + + credentials, project_id = _default.default(request=mock.sentinel.request) + + assert isinstance(credentials, identity_pool.Credentials) + assert project_id is mock.sentinel.project_id + # default() will initialize new credentials via with_scopes_if_required + # and potentially with_quota_project. + # As a result the caller of get_project_id() will not match the returned + # credentials. + get_project_id.assert_called_with(mock.ANY, request=mock.sentinel.request) + + +def test_default_environ_external_credentials_bad_format(monkeypatch, tmpdir): + filename = tmpdir.join("external_account_bad.json") + filename.write(json.dumps({"type": "external_account"})) + monkeypatch.setenv(environment_vars.CREDENTIALS, str(filename)) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.default() + + assert excinfo.match( + "Failed to load external account credentials from {}".format(str(filename)) + ) diff --git a/tests/test_aws.py b/tests/test_aws.py new file mode 100644 index 000000000..9a8f98eec --- /dev/null +++ b/tests/test_aws.py @@ -0,0 +1,1434 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json + +import mock +import pytest +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import aws +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import transport + + +CLIENT_ID = "username" +CLIENT_SECRET = "password" +# Base64 encoding of "username:password". +BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" +SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com" +SERVICE_ACCOUNT_IMPERSONATION_URL = ( + "https://us-east1-iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL) +) +QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID" +SCOPES = ["scope1", "scope2"] +TOKEN_URL = "https://sts.googleapis.com/v1/token" +SUBJECT_TOKEN_TYPE = "urn:ietf:params:aws:token-type:aws4_request" +AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID" +REGION_URL = "http://169.254.169.254/latest/meta-data/placement/availability-zone" +SECURITY_CREDS_URL = "http://169.254.169.254/latest/meta-data/iam/security-credentials" +CRED_VERIFICATION_URL = ( + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15" +) +# Sample AWS security credentials to be used with tests that require a session token. +ACCESS_KEY_ID = "ASIARD4OQDT6A77FR3CL" +SECRET_ACCESS_KEY = "Y8AfSaucF37G4PpvfguKZ3/l7Id4uocLXxX0+VTx" +TOKEN = "IQoJb3JpZ2luX2VjEIz//////////wEaCXVzLWVhc3QtMiJGMEQCIH7MHX/Oy/OB8OlLQa9GrqU1B914+iMikqWQW7vPCKlgAiA/Lsv8Jcafn14owfxXn95FURZNKaaphj0ykpmS+Ki+CSq0AwhlEAAaDDA3NzA3MTM5MTk5NiIMx9sAeP1ovlMTMKLjKpEDwuJQg41/QUKx0laTZYjPlQvjwSqS3OB9P1KAXPWSLkliVMMqaHqelvMF/WO/glv3KwuTfQsavRNs3v5pcSEm4SPO3l7mCs7KrQUHwGP0neZhIKxEXy+Ls//1C/Bqt53NL+LSbaGv6RPHaX82laz2qElphg95aVLdYgIFY6JWV5fzyjgnhz0DQmy62/Vi8pNcM2/VnxeCQ8CC8dRDSt52ry2v+nc77vstuI9xV5k8mPtnaPoJDRANh0bjwY5Sdwkbp+mGRUJBAQRlNgHUJusefXQgVKBCiyJY4w3Csd8Bgj9IyDV+Azuy1jQqfFZWgP68LSz5bURyIjlWDQunO82stZ0BgplKKAa/KJHBPCp8Qi6i99uy7qh76FQAqgVTsnDuU6fGpHDcsDSGoCls2HgZjZFPeOj8mmRhFk1Xqvkbjuz8V1cJk54d3gIJvQt8gD2D6yJQZecnuGWd5K2e2HohvCc8Fc9kBl1300nUJPV+k4tr/A5R/0QfEKOZL1/k5lf1g9CREnrM8LVkGxCgdYMxLQow1uTL+QU67AHRRSp5PhhGX4Rek+01vdYSnJCMaPhSEgcLqDlQkhk6MPsyT91QMXcWmyO+cAZwUPwnRamFepuP4K8k2KVXs/LIJHLELwAZ0ekyaS7CptgOqS7uaSTFG3U+vzFZLEnGvWQ7y9IPNQZ+Dffgh4p3vF4J68y9049sI6Sr5d5wbKkcbm8hdCDHZcv4lnqohquPirLiFQ3q7B17V9krMPu3mz1cg4Ekgcrn/E09NTsxAqD8NcZ7C7ECom9r+X3zkDOxaajW6hu3Az8hGlyylDaMiFfRbBJpTIlxp7jfa7CxikNgNtEKLH9iCzvuSg2vhA==" +# To avoid json.dumps() differing behavior from one version to other, +# the JSON payload is hardcoded. +REQUEST_PARAMS = '{"KeySchema":[{"KeyType":"HASH","AttributeName":"Id"}],"TableName":"TestTable","AttributeDefinitions":[{"AttributeName":"Id","AttributeType":"S"}],"ProvisionedThroughput":{"WriteCapacityUnits":5,"ReadCapacityUnits":5}}' +# Each tuple contains the following entries: +# region, time, credentials, original_request, signed_request +TEST_FIXTURES = [ + # GET request (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with relative path (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-relative-relative.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-relative-relative.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/foo/bar/../..", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/foo/bar/../..", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with /./ path (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-slash-dot-slash.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-slash-dot-slash.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/./", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/./", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b27ccfbfa7df52a200ff74193ca6e32d4b48b8856fab7ebf1c595d0670a7e470", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with pointless dot path (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-slash-pointless-dot.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-slash-pointless-dot.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/./foo", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/./foo", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=910e4d6c9abafaf87898e1eb4c929135782ea25bb0279703146455745391e63a", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with utf8 path (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-utf8.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-utf8.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/%E1%88%B4", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/%E1%88%B4", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=8d6634c189aa8c75c2e51e106b6b5121bed103fdb351f7d7d4381c738823af74", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with duplicate query key (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla-query-order-key-case.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla-query-order-key-case.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/?foo=Zoo&foo=aha", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/?foo=Zoo&foo=aha", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=be7148d34ebccdc6423b19085378aa0bee970bdc61d144bd1a8c48c33079ab09", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with duplicate out of order query key (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla-query-order-value.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla-query-order-value.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/?foo=b&foo=a", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/?foo=b&foo=a", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=feb926e49e382bec75c9d7dcb2a1b6dc8aa50ca43c25d2bc51143768c0875acc", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with utf8 query (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla-ut8-query.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-vanilla-ut8-query.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "GET", + "url": "https://host.foo.com/?{}=bar".format( + urllib.parse.unquote("%E1%88%B4") + ), + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/?{}=bar".format( + urllib.parse.unquote("%E1%88%B4") + ), + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=6fb359e9a05394cc7074e0feb42573a2601abc0c869a953e8c5c12e4e01f1a8c", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # POST request with sorted headers (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-header-key-sort.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-header-key-sort.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "POST", + "url": "https://host.foo.com/", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT", "ZOO": "zoobar"}, + }, + { + "url": "https://host.foo.com/", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host;zoo, Signature=b7a95a52518abbca0964a999a880429ab734f35ebbf1235bd79a5de87756dc4a", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + "ZOO": "zoobar", + }, + }, + ), + # POST request with upper case header value from AWS Python test harness. + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-header-value-case.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-header-value-case.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "POST", + "url": "https://host.foo.com/", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT", "zoo": "ZOOBAR"}, + }, + { + "url": "https://host.foo.com/", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host;zoo, Signature=273313af9d0c265c531e11db70bbd653f3ba074c1009239e8559d3987039cad7", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + "zoo": "ZOOBAR", + }, + }, + ), + # POST request with header and no body (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-header-value-trim.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/get-header-value-trim.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "POST", + "url": "https://host.foo.com/", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT", "p": "phfft"}, + }, + { + "url": "https://host.foo.com/", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host;p, Signature=debf546796015d6f6ded8626f5ce98597c33b47b9164cf6b17b4642036fcb592", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + "p": "phfft", + }, + }, + ), + # POST request with body and no header (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-x-www-form-urlencoded.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-x-www-form-urlencoded.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "POST", + "url": "https://host.foo.com/", + "headers": { + "Content-Type": "application/x-www-form-urlencoded", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + "data": "foo=bar", + }, + { + "url": "https://host.foo.com/", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=content-type;date;host, Signature=5a15b22cf462f047318703b92e6f4f38884e4a7ab7b1d6426ca46a8bd1c26cbc", + "host": "host.foo.com", + "Content-Type": "application/x-www-form-urlencoded", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + "data": "foo=bar", + }, + ), + # POST request with querystring (AWS botocore tests). + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-vanilla-query.req + # https://github.com/boto/botocore/blob/879f8440a4e9ace5d3cf145ce8b3d5e5ffb892ef/tests/unit/auth/aws4_testsuite/post-vanilla-query.sreq + ( + "us-east-1", + "2011-09-09T23:36:00Z", + { + "access_key_id": "AKIDEXAMPLE", + "secret_access_key": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + }, + { + "method": "POST", + "url": "https://host.foo.com/?foo=bar", + "headers": {"date": "Mon, 09 Sep 2011 23:36:00 GMT"}, + }, + { + "url": "https://host.foo.com/?foo=bar", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20110909/us-east-1/host/aws4_request, SignedHeaders=date;host, Signature=b6e3b79003ce0743a491606ba1035a804593b0efb1e20a11cba83f8c25a57a92", + "host": "host.foo.com", + "date": "Mon, 09 Sep 2011 23:36:00 GMT", + }, + }, + ), + # GET request with session token credentials. + ( + "us-east-2", + "2020-08-11T06:55:22Z", + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + }, + { + "method": "GET", + "url": "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", + }, + { + "url": "https://ec2.us-east-2.amazonaws.com?Action=DescribeRegions&Version=2013-10-15", + "method": "GET", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=" + + ACCESS_KEY_ID + + "/20200811/us-east-2/ec2/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=631ea80cddfaa545fdadb120dc92c9f18166e38a5c47b50fab9fce476e022855", + "host": "ec2.us-east-2.amazonaws.com", + "x-amz-date": "20200811T065522Z", + "x-amz-security-token": TOKEN, + }, + }, + ), + # POST request with session token credentials. + ( + "us-east-2", + "2020-08-11T06:55:22Z", + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + }, + { + "method": "POST", + "url": "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + }, + { + "url": "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=" + + ACCESS_KEY_ID + + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=73452984e4a880ffdc5c392355733ec3f5ba310d5e0609a89244440cadfe7a7a", + "host": "sts.us-east-2.amazonaws.com", + "x-amz-date": "20200811T065522Z", + "x-amz-security-token": TOKEN, + }, + }, + ), + # POST request with computed x-amz-date and no data. + ( + "us-east-2", + "2020-08-11T06:55:22Z", + {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY}, + { + "method": "POST", + "url": "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + }, + { + "url": "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=" + + ACCESS_KEY_ID + + "/20200811/us-east-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=d095ba304919cd0d5570ba8a3787884ee78b860f268ed040ba23831d55536d56", + "host": "sts.us-east-2.amazonaws.com", + "x-amz-date": "20200811T065522Z", + }, + }, + ), + # POST request with session token and additional headers/data. + ( + "us-east-2", + "2020-08-11T06:55:22Z", + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + }, + { + "method": "POST", + "url": "https://dynamodb.us-east-2.amazonaws.com/", + "headers": { + "Content-Type": "application/x-amz-json-1.0", + "x-amz-target": "DynamoDB_20120810.CreateTable", + }, + "data": REQUEST_PARAMS, + }, + { + "url": "https://dynamodb.us-east-2.amazonaws.com/", + "method": "POST", + "headers": { + "Authorization": "AWS4-HMAC-SHA256 Credential=" + + ACCESS_KEY_ID + + "/20200811/us-east-2/dynamodb/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-security-token;x-amz-target, Signature=fdaa5b9cc9c86b80fe61eaf504141c0b3523780349120f2bd8145448456e0385", + "host": "dynamodb.us-east-2.amazonaws.com", + "x-amz-date": "20200811T065522Z", + "Content-Type": "application/x-amz-json-1.0", + "x-amz-target": "DynamoDB_20120810.CreateTable", + "x-amz-security-token": TOKEN, + }, + "data": REQUEST_PARAMS, + }, + ), +] + + +class TestRequestSigner(object): + @pytest.mark.parametrize( + "region, time, credentials, original_request, signed_request", TEST_FIXTURES + ) + @mock.patch("google.auth._helpers.utcnow") + def test_get_request_options( + self, utcnow, region, time, credentials, original_request, signed_request + ): + utcnow.return_value = datetime.datetime.strptime(time, "%Y-%m-%dT%H:%M:%SZ") + request_signer = aws.RequestSigner(region) + actual_signed_request = request_signer.get_request_options( + credentials, + original_request.get("url"), + original_request.get("method"), + original_request.get("data"), + original_request.get("headers"), + ) + + assert actual_signed_request == signed_request + + def test_get_request_options_with_missing_scheme_url(self): + request_signer = aws.RequestSigner("us-east-2") + + with pytest.raises(ValueError) as excinfo: + request_signer.get_request_options( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + }, + "invalid", + "POST", + ) + + assert excinfo.match(r"Invalid AWS service URL") + + def test_get_request_options_with_invalid_scheme_url(self): + request_signer = aws.RequestSigner("us-east-2") + + with pytest.raises(ValueError) as excinfo: + request_signer.get_request_options( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + }, + "http://invalid", + "POST", + ) + + assert excinfo.match(r"Invalid AWS service URL") + + def test_get_request_options_with_missing_hostname_url(self): + request_signer = aws.RequestSigner("us-east-2") + + with pytest.raises(ValueError) as excinfo: + request_signer.get_request_options( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + }, + "https://", + "POST", + ) + + assert excinfo.match(r"Invalid AWS service URL") + + +class TestCredentials(object): + AWS_REGION = "us-east-2" + AWS_ROLE = "gcp-aws-role" + AWS_SECURITY_CREDENTIALS_RESPONSE = { + "AccessKeyId": ACCESS_KEY_ID, + "SecretAccessKey": SECRET_ACCESS_KEY, + "Token": TOKEN, + } + AWS_SIGNATURE_TIME = "2020-08-11T06:55:22Z" + CREDENTIAL_SOURCE = { + "environment_id": "aws1", + "region_url": REGION_URL, + "url": SECURITY_CREDS_URL, + "regional_cred_verification_url": CRED_VERIFICATION_URL, + } + SUCCESS_RESPONSE = { + "access_token": "ACCESS_TOKEN", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": " ".join(SCOPES), + } + + @classmethod + def make_serialized_aws_signed_request( + cls, + aws_security_credentials, + region_name="us-east-2", + url="https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + ): + """Utility to generate serialize AWS signed requests. + This makes it easy to assert generated subject tokens based on the + provided AWS security credentials, regions and AWS STS endpoint. + """ + request_signer = aws.RequestSigner(region_name) + signed_request = request_signer.get_request_options( + aws_security_credentials, url, "POST" + ) + reformatted_signed_request = { + "url": signed_request.get("url"), + "method": signed_request.get("method"), + "headers": [ + { + "key": "Authorization", + "value": signed_request.get("headers").get("Authorization"), + }, + {"key": "host", "value": signed_request.get("headers").get("host")}, + { + "key": "x-amz-date", + "value": signed_request.get("headers").get("x-amz-date"), + }, + ], + } + # Include security token if available. + if "security_token" in aws_security_credentials: + reformatted_signed_request.get("headers").append( + { + "key": "x-amz-security-token", + "value": signed_request.get("headers").get("x-amz-security-token"), + } + ) + # Append x-goog-cloud-target-resource header. + reformatted_signed_request.get("headers").append( + {"key": "x-goog-cloud-target-resource", "value": AUDIENCE} + ), + return urllib.parse.quote( + json.dumps( + reformatted_signed_request, separators=(",", ":"), sort_keys=True + ) + ) + + @classmethod + def make_mock_request( + cls, + region_status=None, + region_name=None, + role_status=None, + role_name=None, + security_credentials_status=None, + security_credentials_data=None, + token_status=None, + token_data=None, + impersonation_status=None, + impersonation_data=None, + ): + """Utility function to generate a mock HTTP request object. + This will facilitate testing various edge cases by specify how the + various endpoints will respond while generating a Google Access token + in an AWS environment. + """ + responses = [] + if region_status: + # AWS region request. + region_response = mock.create_autospec(transport.Response, instance=True) + region_response.status = region_status + if region_name: + region_response.data = "{}b".format(region_name).encode("utf-8") + responses.append(region_response) + + if role_status: + # AWS role name request. + role_response = mock.create_autospec(transport.Response, instance=True) + role_response.status = role_status + if role_name: + role_response.data = role_name.encode("utf-8") + responses.append(role_response) + + if security_credentials_status: + # AWS security credentials request. + security_credentials_response = mock.create_autospec( + transport.Response, instance=True + ) + security_credentials_response.status = security_credentials_status + if security_credentials_data: + security_credentials_response.data = json.dumps( + security_credentials_data + ).encode("utf-8") + responses.append(security_credentials_response) + + if token_status: + # GCP token exchange request. + token_response = mock.create_autospec(transport.Response, instance=True) + token_response.status = token_status + token_response.data = json.dumps(token_data).encode("utf-8") + responses.append(token_response) + + if impersonation_status: + # Service account impersonation request. + impersonation_response = mock.create_autospec( + transport.Response, instance=True + ) + impersonation_response.status = impersonation_status + impersonation_response.data = json.dumps(impersonation_data).encode("utf-8") + responses.append(impersonation_response) + + request = mock.create_autospec(transport.Request) + request.side_effect = responses + + return request + + @classmethod + def make_credentials( + cls, + credential_source, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + service_account_impersonation_url=None, + ): + return aws.Credentials( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=service_account_impersonation_url, + credential_source=credential_source, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + + @classmethod + def assert_aws_metadata_request_kwargs(cls, request_kwargs, url, headers=None): + assert request_kwargs["url"] == url + # All used AWS metadata server endpoints use GET HTTP method. + assert request_kwargs["method"] == "GET" + if headers: + assert request_kwargs["headers"] == headers + else: + assert "headers" not in request_kwargs + # None of the endpoints used require any data in request. + assert "body" not in request_kwargs + + @classmethod + def assert_token_request_kwargs( + cls, request_kwargs, headers, request_data, token_url=TOKEN_URL + ): + assert request_kwargs["url"] == token_url + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_tuples = urllib.parse.parse_qsl(request_kwargs["body"]) + assert len(body_tuples) == len(request_data.keys()) + for (k, v) in body_tuples: + assert v.decode("utf-8") == request_data[k.decode("utf-8")] + + @classmethod + def assert_impersonation_request_kwargs( + cls, + request_kwargs, + headers, + request_data, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + ): + assert request_kwargs["url"] == service_account_impersonation_url + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_json = json.loads(request_kwargs["body"].decode("utf-8")) + assert body_json == request_data + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_info_full_options(self, mock_init): + credentials = aws.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "quota_project_id": QUOTA_PROJECT_ID, + "credential_source": self.CREDENTIAL_SOURCE, + } + ) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE, + quota_project_id=QUOTA_PROJECT_ID, + ) + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_info_required_options_only(self, mock_init): + credentials = aws.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE, + } + ) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE, + quota_project_id=None, + ) + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_file_full_options(self, mock_init, tmpdir): + info = { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "quota_project_id": QUOTA_PROJECT_ID, + "credential_source": self.CREDENTIAL_SOURCE, + } + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(info)) + credentials = aws.Credentials.from_file(str(config_file)) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE, + quota_project_id=QUOTA_PROJECT_ID, + ) + + @mock.patch.object(aws.Credentials, "__init__", return_value=None) + def test_from_file_required_options_only(self, mock_init, tmpdir): + info = { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE, + } + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(info)) + credentials = aws.Credentials.from_file(str(config_file)) + + # Confirm aws.Credentials instance initialized with the expected parameters. + assert isinstance(credentials, aws.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE, + quota_project_id=None, + ) + + def test_constructor_invalid_credential_source(self): + # Provide invalid credential source. + credential_source = {"unsupported": "value"} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"No valid AWS 'credential_source' provided") + + def test_constructor_invalid_environment_id(self): + # Provide invalid environment_id. + credential_source = self.CREDENTIAL_SOURCE.copy() + credential_source["environment_id"] = "azure1" + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"No valid AWS 'credential_source' provided") + + def test_constructor_missing_cred_verification_url(self): + # regional_cred_verification_url is a required field. + credential_source = self.CREDENTIAL_SOURCE.copy() + credential_source.pop("regional_cred_verification_url") + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"No valid AWS 'credential_source' provided") + + def test_constructor_invalid_environment_id_version(self): + # Provide an unsupported version. + credential_source = self.CREDENTIAL_SOURCE.copy() + credential_source["environment_id"] = "aws3" + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"aws version '3' is not supported in the current build.") + + def test_retrieve_subject_token_missing_region_url(self): + # When AWS_REGION envvar is not available, region_url is required for + # determining the current AWS region. + credential_source = self.CREDENTIAL_SOURCE.copy() + credential_source.pop("region_url") + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match(r"Unable to determine AWS region") + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_temp_creds_no_environment_vars( + self, utcnow + ): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + # Assert region request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[0].kwargs, REGION_URL + ) + # Assert role request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[1].kwargs, SECURITY_CREDS_URL + ) + # Assert security credentials request. + self.assert_aws_metadata_request_kwargs( + request.call_args_list[2].kwargs, + "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE), + {"Content-Type": "application/json"}, + ) + + # Retrieve subject_token again. Region should not be queried again. + new_request = self.make_mock_request( + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + ) + + credentials.retrieve_subject_token(new_request) + + # Only 2 requests should be sent as the region is cached. + assert len(new_request.call_args_list) == 2 + # Assert role request. + self.assert_aws_metadata_request_kwargs( + new_request.call_args_list[0].kwargs, SECURITY_CREDS_URL + ) + # Assert security credentials request. + self.assert_aws_metadata_request_kwargs( + new_request.call_args_list[1].kwargs, + "{}/{}".format(SECURITY_CREDS_URL, self.AWS_ROLE), + {"Content-Type": "application/json"}, + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_permanent_creds_no_environment_vars( + self, utcnow + ): + # Simualte a permanent credential without a session token is + # returned by the security-credentials endpoint. + security_creds_response = self.AWS_SECURITY_CREDENTIALS_RESPONSE.copy() + security_creds_response.pop("Token") + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=security_creds_response, + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == self.make_serialized_aws_signed_request( + {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY} + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_environment_vars(self, utcnow, monkeypatch): + monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID) + monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY) + monkeypatch.setenv(environment_vars.AWS_SESSION_TOKEN, TOKEN) + monkeypatch.setenv(environment_vars.AWS_REGION, self.AWS_REGION) + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_environment_vars_no_session_token( + self, utcnow, monkeypatch + ): + monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID) + monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY) + monkeypatch.setenv(environment_vars.AWS_REGION, self.AWS_REGION) + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == self.make_serialized_aws_signed_request( + {"access_key_id": ACCESS_KEY_ID, "secret_access_key": SECRET_ACCESS_KEY} + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_environment_vars_except_region( + self, utcnow, monkeypatch + ): + monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID) + monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY) + monkeypatch.setenv(environment_vars.AWS_SESSION_TOKEN, TOKEN) + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + # Region will be queried since it is not found in envvars. + request = self.make_mock_request( + region_status=http_client.OK, region_name=self.AWS_REGION + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + + def test_retrieve_subject_token_error_determining_aws_region(self): + # Simulate error in retrieving the AWS region. + request = self.make_mock_request(region_status=http_client.BAD_REQUEST) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(request) + + assert excinfo.match(r"Unable to retrieve AWS region") + + def test_retrieve_subject_token_error_determining_aws_role(self): + # Simulate error in retrieving the AWS role name. + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.BAD_REQUEST, + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(request) + + assert excinfo.match(r"Unable to retrieve AWS role name") + + def test_retrieve_subject_token_error_determining_security_creds_url(self): + # Simulate the security-credentials url is missing. This is needed for + # determining the AWS security credentials when not found in envvars. + credential_source = self.CREDENTIAL_SOURCE.copy() + credential_source.pop("url") + request = self.make_mock_request( + region_status=http_client.OK, region_name=self.AWS_REGION + ) + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(request) + + assert excinfo.match( + r"Unable to determine the AWS metadata server security credentials endpoint" + ) + + def test_retrieve_subject_token_error_determining_aws_security_creds(self): + # Simulate error in retrieving the AWS security credentials. + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.BAD_REQUEST, + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(request) + + assert excinfo.match(r"Unable to retrieve AWS security credentials") + + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_without_impersonation_ignore_default_scopes(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expected_subject_token = self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": " ".join(SCOPES), + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + ) + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 4 + # Fourth request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[3].kwargs, token_headers, token_request_data + ) + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] + + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_without_impersonation_use_default_scopes(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expected_subject_token = self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": " ".join(SCOPES), + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + ) + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE, + quota_project_id=QUOTA_PROJECT_ID, + scopes=None, + # Default scopes should be used since user specified scopes are none. + default_scopes=SCOPES, + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 4 + # Fourth request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[3].kwargs, token_headers, token_request_data + ) + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes is None + assert credentials.default_scopes == SCOPES + + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_impersonation_ignore_default_scopes(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + expected_subject_token = self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "https://www.googleapis.com/auth/iam", + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + "x-goog-user-project": QUOTA_PROJECT_ID, + } + impersonation_request_data = { + "delegates": None, + "scope": SCOPES, + "lifetime": "3600s", + } + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + quota_project_id=QUOTA_PROJECT_ID, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 5 + # Fourth request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[3].kwargs, token_headers, token_request_data + ) + # Fifth request should be sent to iamcredentials endpoint for service + # account impersonation. + self.assert_impersonation_request_kwargs( + request.call_args_list[4].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.token == impersonation_response["accessToken"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes == SCOPES + assert credentials.default_scopes == ["ignored"] + + @mock.patch("google.auth._helpers.utcnow") + def test_refresh_success_with_impersonation_use_default_scopes(self, utcnow): + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + expected_subject_token = self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic " + BASIC_AUTH_ENCODING, + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "https://www.googleapis.com/auth/iam", + "subject_token": expected_subject_token, + "subject_token_type": SUBJECT_TOKEN_TYPE, + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + "x-goog-user-project": QUOTA_PROJECT_ID, + } + impersonation_request_data = { + "delegates": None, + "scope": SCOPES, + "lifetime": "3600s", + } + request = self.make_mock_request( + region_status=http_client.OK, + region_name=self.AWS_REGION, + role_status=http_client.OK, + role_name=self.AWS_ROLE, + security_credentials_status=http_client.OK, + security_credentials_data=self.AWS_SECURITY_CREDENTIALS_RESPONSE, + token_status=http_client.OK, + token_data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + quota_project_id=QUOTA_PROJECT_ID, + scopes=None, + # Default scopes should be used since user specified scopes are none. + default_scopes=SCOPES, + ) + + credentials.refresh(request) + + assert len(request.call_args_list) == 5 + # Fourth request should be sent to GCP STS endpoint. + self.assert_token_request_kwargs( + request.call_args_list[3].kwargs, token_headers, token_request_data + ) + # Fifth request should be sent to iamcredentials endpoint for service + # account impersonation. + self.assert_impersonation_request_kwargs( + request.call_args_list[4].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.token == impersonation_response["accessToken"] + assert credentials.quota_project_id == QUOTA_PROJECT_ID + assert credentials.scopes is None + assert credentials.default_scopes == SCOPES + + def test_refresh_with_retrieve_subject_token_error(self): + request = self.make_mock_request(region_status=http_client.BAD_REQUEST) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Unable to retrieve AWS region") diff --git a/tests/test_external_account.py b/tests/test_external_account.py new file mode 100644 index 000000000..42e53ecb5 --- /dev/null +++ b/tests/test_external_account.py @@ -0,0 +1,1095 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json + +import mock +import pytest +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import external_account +from google.auth import transport + + +CLIENT_ID = "username" +CLIENT_SECRET = "password" +# Base64 encoding of "username:password" +BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" +SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com" + + +class CredentialsImpl(external_account.Credentials): + def __init__( + self, + audience, + subject_token_type, + token_url, + credential_source, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + ): + super(CredentialsImpl, self).__init__( + audience=audience, + subject_token_type=subject_token_type, + token_url=token_url, + credential_source=credential_source, + service_account_impersonation_url=service_account_impersonation_url, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + self._counter = 0 + + def retrieve_subject_token(self, request): + counter = self._counter + self._counter += 1 + return "subject_token_{}".format(counter) + + +class TestCredentials(object): + TOKEN_URL = "https://sts.googleapis.com/v1/token" + PROJECT_NUMBER = "123456" + POOL_ID = "POOL_ID" + PROVIDER_ID = "PROVIDER_ID" + AUDIENCE = ( + "//iam.googleapis.com/projects/{}" + "/locations/global/workloadIdentityPools/{}" + "/providers/{}" + ).format(PROJECT_NUMBER, POOL_ID, PROVIDER_ID) + SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" + CREDENTIAL_SOURCE = {"file": "/var/run/secrets/goog.id/token"} + SUCCESS_RESPONSE = { + "access_token": "ACCESS_TOKEN", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": "scope1 scope2", + } + ERROR_RESPONSE = { + "error": "invalid_request", + "error_description": "Invalid subject token", + "error_uri": "https://tools.ietf.org/html/rfc6749", + } + QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID" + SERVICE_ACCOUNT_IMPERSONATION_URL = ( + "https://us-east1-iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL) + ) + SCOPES = ["scope1", "scope2"] + IMPERSONATION_ERROR_RESPONSE = { + "error": { + "code": 400, + "message": "Request contains an invalid argument", + "status": "INVALID_ARGUMENT", + } + } + PROJECT_ID = "my-proj-id" + CLOUD_RESOURCE_MANAGER_URL = ( + "https://cloudresourcemanager.googleapis.com/v1/projects/" + ) + CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE = { + "projectNumber": PROJECT_NUMBER, + "projectId": PROJECT_ID, + "lifecycleState": "ACTIVE", + "name": "project-name", + "createTime": "2018-11-06T04:42:54.109Z", + "parent": {"type": "folder", "id": "12345678901"}, + } + + @classmethod + def make_credentials( + cls, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + service_account_impersonation_url=None, + ): + return CredentialsImpl( + audience=cls.AUDIENCE, + subject_token_type=cls.SUBJECT_TOKEN_TYPE, + token_url=cls.TOKEN_URL, + service_account_impersonation_url=service_account_impersonation_url, + credential_source=cls.CREDENTIAL_SOURCE, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + + @classmethod + def make_mock_request( + cls, + status=http_client.OK, + data=None, + impersonation_status=None, + impersonation_data=None, + cloud_resource_manager_status=None, + cloud_resource_manager_data=None, + ): + # STS token exchange request. + token_response = mock.create_autospec(transport.Response, instance=True) + token_response.status = status + token_response.data = json.dumps(data).encode("utf-8") + responses = [token_response] + + # If service account impersonation is requested, mock the expected response. + if impersonation_status: + impersonation_response = mock.create_autospec( + transport.Response, instance=True + ) + impersonation_response.status = impersonation_status + impersonation_response.data = json.dumps(impersonation_data).encode("utf-8") + responses.append(impersonation_response) + + # If cloud resource manager is requested, mock the expected response. + if cloud_resource_manager_status: + cloud_resource_manager_response = mock.create_autospec( + transport.Response, instance=True + ) + cloud_resource_manager_response.status = cloud_resource_manager_status + cloud_resource_manager_response.data = json.dumps( + cloud_resource_manager_data + ).encode("utf-8") + responses.append(cloud_resource_manager_response) + + request = mock.create_autospec(transport.Request) + request.side_effect = responses + + return request + + @classmethod + def assert_token_request_kwargs(cls, request_kwargs, headers, request_data): + assert request_kwargs["url"] == cls.TOKEN_URL + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_tuples = urllib.parse.parse_qsl(request_kwargs["body"]) + for (k, v) in body_tuples: + assert v.decode("utf-8") == request_data[k.decode("utf-8")] + assert len(body_tuples) == len(request_data.keys()) + + @classmethod + def assert_impersonation_request_kwargs(cls, request_kwargs, headers, request_data): + assert request_kwargs["url"] == cls.SERVICE_ACCOUNT_IMPERSONATION_URL + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_json = json.loads(request_kwargs["body"].decode("utf-8")) + assert body_json == request_data + + @classmethod + def assert_resource_manager_request_kwargs( + cls, request_kwargs, project_number, headers + ): + assert request_kwargs["url"] == cls.CLOUD_RESOURCE_MANAGER_URL + project_number + assert request_kwargs["method"] == "GET" + assert request_kwargs["headers"] == headers + assert "body" not in request_kwargs + + def test_default_state(self): + credentials = self.make_credentials() + + # Not token acquired yet + assert not credentials.token + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expiry + assert not credentials.expired + # Scopes are required + assert not credentials.scopes + assert credentials.requires_scopes + assert not credentials.quota_project_id + + def test_with_scopes(self): + credentials = self.make_credentials() + + assert not credentials.scopes + assert credentials.requires_scopes + + scoped_credentials = credentials.with_scopes(["email"]) + + assert scoped_credentials.has_scopes(["email"]) + assert not scoped_credentials.requires_scopes + + def test_with_scopes_using_user_and_default_scopes(self): + credentials = self.make_credentials() + + assert not credentials.scopes + assert credentials.requires_scopes + + scoped_credentials = credentials.with_scopes( + ["email"], default_scopes=["profile"] + ) + + assert scoped_credentials.has_scopes(["email"]) + assert not scoped_credentials.has_scopes(["profile"]) + assert not scoped_credentials.requires_scopes + assert scoped_credentials.scopes == ["email"] + assert scoped_credentials.default_scopes == ["profile"] + + def test_with_scopes_using_default_scopes_only(self): + credentials = self.make_credentials() + + assert not credentials.scopes + assert credentials.requires_scopes + + scoped_credentials = credentials.with_scopes(None, default_scopes=["profile"]) + + assert scoped_credentials.has_scopes(["profile"]) + assert not scoped_credentials.requires_scopes + + def test_with_scopes_full_options_propagated(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + quota_project_id=self.QUOTA_PROJECT_ID, + scopes=self.SCOPES, + default_scopes=["default1"], + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + ) + + with mock.patch.object( + external_account.Credentials, "__init__", return_value=None + ) as mock_init: + credentials.with_scopes(["email"], ["default2"]) + + # Confirm with_scopes initialized the credential with the expected + # parameters and scopes. + mock_init.assert_called_once_with( + audience=self.AUDIENCE, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + quota_project_id=self.QUOTA_PROJECT_ID, + scopes=["email"], + default_scopes=["default2"], + ) + + def test_with_quota_project(self): + credentials = self.make_credentials() + + assert not credentials.scopes + assert not credentials.quota_project_id + + quota_project_creds = credentials.with_quota_project("project-foo") + + assert quota_project_creds.quota_project_id == "project-foo" + + def test_with_quota_project_full_options_propagated(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + quota_project_id=self.QUOTA_PROJECT_ID, + scopes=self.SCOPES, + default_scopes=["default1"], + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + ) + + with mock.patch.object( + external_account.Credentials, "__init__", return_value=None + ) as mock_init: + credentials.with_quota_project("project-foo") + + # Confirm with_quota_project initialized the credential with the + # expected parameters and quota project ID. + mock_init.assert_called_once_with( + audience=self.AUDIENCE, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + quota_project_id="project-foo", + scopes=self.SCOPES, + default_scopes=["default1"], + ) + + def test_with_invalid_impersonation_target_principal(self): + invalid_url = "https://iamcredentials.googleapis.com/v1/invalid" + + with pytest.raises(exceptions.RefreshError) as excinfo: + self.make_credentials(service_account_impersonation_url=invalid_url) + + assert excinfo.match( + r"Unable to determine target principal from service account impersonation URL." + ) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_without_client_auth_success(self, unused_utcnow): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request(status=http_client.OK, data=response) + credentials = self.make_credentials() + + credentials.refresh(request) + + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + + def test_refresh_impersonation_without_client_auth_success(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0].kwargs, token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + + def test_refresh_without_client_auth_success_explicit_user_scopes_ignore_default_scopes( + self + ): + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "scope1 scope2", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials( + scopes=["scope1", "scope2"], + # Default scopes will be ignored in favor of user scopes. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) + assert credentials.valid + assert not credentials.expired + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.has_scopes(["scope1", "scope2"]) + assert not credentials.has_scopes(["ignored"]) + + def test_refresh_without_client_auth_success_explicit_default_scopes_only(self): + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "scope1 scope2", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials( + scopes=None, + # Default scopes will be used since user scopes are none. + default_scopes=["scope1", "scope2"], + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) + assert credentials.valid + assert not credentials.expired + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + assert credentials.has_scopes(["scope1", "scope2"]) + + def test_refresh_without_client_auth_error(self): + request = self.make_mock_request( + status=http_client.BAD_REQUEST, data=self.ERROR_RESPONSE + ) + credentials = self.make_credentials() + + with pytest.raises(exceptions.OAuthError) as excinfo: + credentials.refresh(request) + + assert excinfo.match( + r"Error code invalid_request: Invalid subject token - https://tools.ietf.org/html/rfc6749" + ) + assert not credentials.expired + assert credentials.token is None + + def test_refresh_impersonation_without_client_auth_error(self): + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE, + impersonation_status=http_client.BAD_REQUEST, + impersonation_data=self.IMPERSONATION_ERROR_RESPONSE, + ) + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(request) + + assert excinfo.match(r"Unable to acquire impersonated credentials") + assert not credentials.expired + assert credentials.token is None + + def test_refresh_with_client_auth_success(self): + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials( + client_id=CLIENT_ID, client_secret=CLIENT_SECRET + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs( + request.call_args.kwargs, headers, request_data + ) + assert credentials.valid + assert not credentials.expired + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + + def test_refresh_impersonation_with_client_auth_success_ignore_default_scopes(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + # Default scopes will be ignored since user scopes are specified. + default_scopes=["ignored"], + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0].kwargs, token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + + def test_refresh_impersonation_with_client_auth_success_use_default_scopes(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=token_response, + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=None, + # Default scopes will be used since user specified scopes are none. + default_scopes=self.SCOPES, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0].kwargs, token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1].kwargs, + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + + def test_apply_without_quota_project_id(self): + headers = {} + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials() + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) + } + + def test_apply_impersonation_without_quota_project_id(self): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + ) + headers = {} + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "authorization": "Bearer {}".format(impersonation_response["accessToken"]) + } + + def test_apply_with_quota_project_id(self): + headers = {"other": "header-value"} + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials(quota_project_id=self.QUOTA_PROJECT_ID) + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + "x-goog-user-project": self.QUOTA_PROJECT_ID, + } + + def test_apply_impersonation_with_quota_project_id(self): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + quota_project_id=self.QUOTA_PROJECT_ID, + ) + headers = {"other": "header-value"} + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(impersonation_response["accessToken"]), + "x-goog-user-project": self.QUOTA_PROJECT_ID, + } + + def test_before_request(self): + headers = {"other": "header-value"} + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials() + + # First call should call refresh, setting the token. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + } + + # Second call shouldn't call refresh. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + } + + def test_before_request_impersonation(self): + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + headers = {"other": "header-value"} + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL + ) + + # First call should call refresh, setting the token. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(impersonation_response["accessToken"]), + } + + # Second call shouldn't call refresh. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(impersonation_response["accessToken"]), + } + + @mock.patch("google.auth._helpers.utcnow") + def test_before_request_expired(self, utcnow): + headers = {} + request = self.make_mock_request( + status=http_client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_credentials() + credentials.token = "token" + utcnow.return_value = datetime.datetime.min + # Set the expiration to one second more than now plus the clock skew + # accomodation. These credentials should be valid. + credentials.expiry = ( + datetime.datetime.min + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1) + ) + + assert credentials.valid + assert not credentials.expired + + credentials.before_request(request, "POST", "https://example.com/api", headers) + + # Cached token should be used. + assert headers == {"authorization": "Bearer token"} + + # Next call should simulate 1 second passed. + utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1) + + assert not credentials.valid + assert credentials.expired + + credentials.before_request(request, "POST", "https://example.com/api", headers) + + # New token should be retrieved. + assert headers == { + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) + } + + @mock.patch("google.auth._helpers.utcnow") + def test_before_request_impersonation_expired(self, utcnow): + headers = {} + expire_time = ( + datetime.datetime.min + datetime.timedelta(seconds=3601) + ).isoformat("T") + "Z" + # Service account impersonation response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + ) + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL + ) + credentials.token = "token" + utcnow.return_value = datetime.datetime.min + # Set the expiration to one second more than now plus the clock skew + # accomodation. These credentials should be valid. + credentials.expiry = ( + datetime.datetime.min + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1) + ) + + assert credentials.valid + assert not credentials.expired + + credentials.before_request(request, "POST", "https://example.com/api", headers) + + # Cached token should be used. + assert headers == {"authorization": "Bearer token"} + + # Next call should simulate 1 second passed. This will trigger the expiration + # threshold. + utcnow.return_value = datetime.datetime.min + datetime.timedelta(seconds=1) + + assert not credentials.valid + assert credentials.expired + + credentials.before_request(request, "POST", "https://example.com/api", headers) + + # New token should be retrieved. + assert headers == { + "authorization": "Bearer {}".format(impersonation_response["accessToken"]) + } + + @pytest.mark.parametrize( + "audience", + [ + # Legacy K8s audience format. + "identitynamespace:1f12345:my_provider", + # Unrealistic audiences. + "//iam.googleapis.com/projects", + "//iam.googleapis.com/projects/", + "//iam.googleapis.com/project/123456", + "//iam.googleapis.com/projects//123456", + "//iam.googleapis.com/prefix_projects/123456", + "//iam.googleapis.com/projects_suffix/123456", + ], + ) + def test_project_number_indeterminable(self, audience): + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.project_number is None + assert credentials.get_project_id(None) is None + + def test_project_number_determinable(self): + credentials = CredentialsImpl( + audience=self.AUDIENCE, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.project_number == self.PROJECT_NUMBER + + def test_project_id_without_scopes(self): + # Initialize credentials with no scopes. + credentials = CredentialsImpl( + audience=self.AUDIENCE, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.get_project_id(None) is None + + def test_get_project_id_cloud_resource_manager_success(self): + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + } + # Service account impersonation request/response. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "x-goog-user-project": self.QUOTA_PROJECT_ID, + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange, service account + # impersonation and cloud resource manager request. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + impersonation_status=http_client.OK, + impersonation_data=impersonation_response, + cloud_resource_manager_status=http_client.OK, + cloud_resource_manager_data=self.CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE, + ) + credentials = self.make_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + quota_project_id=self.QUOTA_PROJECT_ID, + ) + + # Expected project ID from cloud resource manager response should be returned. + project_id = credentials.get_project_id(request) + + assert project_id == self.PROJECT_ID + # 3 requests should be processed. + assert len(request.call_args_list) == 3 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0].kwargs, token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1].kwargs, + impersonation_headers, + impersonation_request_data, + ) + # In the process of getting project ID, an access token should be + # retrieved. + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + # Verify cloud resource manager request parameters. + self.assert_resource_manager_request_kwargs( + request.call_args_list[2].kwargs, + self.PROJECT_NUMBER, + { + "x-goog-user-project": self.QUOTA_PROJECT_ID, + "authorization": "Bearer {}".format( + impersonation_response["accessToken"] + ), + }, + ) + + # Calling get_project_id again should return the cached project_id. + project_id = credentials.get_project_id(request) + + assert project_id == self.PROJECT_ID + # No additional requests. + assert len(request.call_args_list) == 3 + + def test_get_project_id_cloud_resource_manager_error(self): + # Simulate resource doesn't have sufficient permissions to access + # cloud resource manager. + request = self.make_mock_request( + status=http_client.OK, + data=self.SUCCESS_RESPONSE.copy(), + cloud_resource_manager_status=http_client.UNAUTHORIZED, + ) + credentials = self.make_credentials(scopes=self.SCOPES) + + project_id = credentials.get_project_id(request) + + assert project_id is None + # Only 2 requests to STS and cloud resource manager should be sent. + assert len(request.call_args_list) == 2 diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py new file mode 100644 index 000000000..c017ab59f --- /dev/null +++ b/tests/test_identity_pool.py @@ -0,0 +1,873 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os + +import mock +import pytest +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import identity_pool +from google.auth import transport + + +CLIENT_ID = "username" +CLIENT_SECRET = "password" +# Base64 encoding of "username:password". +BASIC_AUTH_ENCODING = "dXNlcm5hbWU6cGFzc3dvcmQ=" +SERVICE_ACCOUNT_EMAIL = "service-1234@service-name.iam.gserviceaccount.com" +SERVICE_ACCOUNT_IMPERSONATION_URL = ( + "https://us-east1-iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL) +) +QUOTA_PROJECT_ID = "QUOTA_PROJECT_ID" +SCOPES = ["scope1", "scope2"] +DATA_DIR = os.path.join(os.path.dirname(__file__), "data") +SUBJECT_TOKEN_TEXT_FILE = os.path.join(DATA_DIR, "external_subject_token.txt") +SUBJECT_TOKEN_JSON_FILE = os.path.join(DATA_DIR, "external_subject_token.json") +SUBJECT_TOKEN_FIELD_NAME = "access_token" + +with open(SUBJECT_TOKEN_TEXT_FILE) as fh: + TEXT_FILE_SUBJECT_TOKEN = fh.read() + +with open(SUBJECT_TOKEN_JSON_FILE) as fh: + JSON_FILE_CONTENT = json.load(fh) + JSON_FILE_SUBJECT_TOKEN = JSON_FILE_CONTENT.get(SUBJECT_TOKEN_FIELD_NAME) + +TOKEN_URL = "https://sts.googleapis.com/v1/token" +SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" +AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID" + + +class TestCredentials(object): + CREDENTIAL_SOURCE_TEXT = {"file": SUBJECT_TOKEN_TEXT_FILE} + CREDENTIAL_SOURCE_JSON = { + "file": SUBJECT_TOKEN_JSON_FILE, + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } + CREDENTIAL_URL = "http://fakeurl.com" + CREDENTIAL_SOURCE_TEXT_URL = {"url": CREDENTIAL_URL} + CREDENTIAL_SOURCE_JSON_URL = { + "url": CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } + SUCCESS_RESPONSE = { + "access_token": "ACCESS_TOKEN", + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + "expires_in": 3600, + "scope": " ".join(SCOPES), + } + + @classmethod + def make_mock_response(cls, status, data): + response = mock.create_autospec(transport.Response, instance=True) + response.status = status + if isinstance(data, dict): + response.data = json.dumps(data).encode("utf-8") + else: + response.data = data + return response + + @classmethod + def make_mock_request( + cls, token_status=http_client.OK, token_data=None, *extra_requests + ): + responses = [] + responses.append(cls.make_mock_response(token_status, token_data)) + + while len(extra_requests) > 0: + # If service account impersonation is requested, mock the expected response. + status, data, extra_requests = ( + extra_requests[0], + extra_requests[1], + extra_requests[2:], + ) + responses.append(cls.make_mock_response(status, data)) + + request = mock.create_autospec(transport.Request) + request.side_effect = responses + + return request + + @classmethod + def assert_credential_request_kwargs( + cls, request_kwargs, headers, url=CREDENTIAL_URL + ): + assert request_kwargs["url"] == url + assert request_kwargs["method"] == "GET" + assert request_kwargs["headers"] == headers + assert request_kwargs.get("body", None) is None + + @classmethod + def assert_token_request_kwargs( + cls, request_kwargs, headers, request_data, token_url=TOKEN_URL + ): + assert request_kwargs["url"] == token_url + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_tuples = urllib.parse.parse_qsl(request_kwargs["body"]) + assert len(body_tuples) == len(request_data.keys()) + for (k, v) in body_tuples: + assert v.decode("utf-8") == request_data[k.decode("utf-8")] + + @classmethod + def assert_impersonation_request_kwargs( + cls, + request_kwargs, + headers, + request_data, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + ): + assert request_kwargs["url"] == service_account_impersonation_url + assert request_kwargs["method"] == "POST" + assert request_kwargs["headers"] == headers + assert request_kwargs["body"] is not None + body_json = json.loads(request_kwargs["body"].decode("utf-8")) + assert body_json == request_data + + @classmethod + def assert_underlying_credentials_refresh( + cls, + credentials, + audience, + subject_token, + subject_token_type, + token_url, + service_account_impersonation_url=None, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=None, + credential_data=None, + scopes=None, + default_scopes=None, + ): + """Utility to assert that a credentials are initialized with the expected + attributes by calling refresh functionality and confirming response matches + expected one and that the underlying requests were populated with the + expected parameters. + """ + # STS token exchange request/response. + token_response = cls.SUCCESS_RESPONSE.copy() + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + if basic_auth_encoding: + token_headers["Authorization"] = "Basic " + basic_auth_encoding + + if service_account_impersonation_url: + token_scopes = "https://www.googleapis.com/auth/iam" + else: + token_scopes = " ".join(used_scopes or []) + + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": audience, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": token_scopes, + "subject_token": subject_token, + "subject_token_type": subject_token_type, + } + + if service_account_impersonation_url: + # Service account impersonation request/response. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + + datetime.timedelta(seconds=3600) + ).isoformat("T") + "Z" + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": used_scopes, + "lifetime": "3600s", + } + + # Initialize mock request to handle token retrieval, token exchange and + # service account impersonation request. + requests = [] + if credential_data: + requests.append((http_client.OK, credential_data)) + + token_request_index = len(requests) + requests.append((http_client.OK, token_response)) + + if service_account_impersonation_url: + impersonation_request_index = len(requests) + requests.append((http_client.OK, impersonation_response)) + + request = cls.make_mock_request(*[el for req in requests for el in req]) + + credentials.refresh(request) + + assert len(request.call_args_list) == len(requests) + if credential_data: + cls.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) + # Verify token exchange request parameters. + cls.assert_token_request_kwargs( + request.call_args_list[token_request_index].kwargs, + token_headers, + token_request_data, + token_url, + ) + # Verify service account impersonation request parameters if the request + # is processed. + if service_account_impersonation_url: + cls.assert_impersonation_request_kwargs( + request.call_args_list[impersonation_request_index].kwargs, + impersonation_headers, + impersonation_request_data, + service_account_impersonation_url, + ) + assert credentials.token == impersonation_response["accessToken"] + else: + assert credentials.token == token_response["access_token"] + assert credentials.quota_project_id == quota_project_id + assert credentials.scopes == scopes + assert credentials.default_scopes == default_scopes + + @classmethod + def make_credentials( + cls, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + service_account_impersonation_url=None, + credential_source=None, + ): + return identity_pool.Credentials( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=service_account_impersonation_url, + credential_source=credential_source, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_full_options(self, mock_init): + credentials = identity_pool.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "quota_project_id": QUOTA_PROJECT_ID, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=QUOTA_PROJECT_ID, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_required_options_only(self, mock_init): + credentials = identity_pool.Credentials.from_info( + { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=None, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_file_full_options(self, mock_init, tmpdir): + info = { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "service_account_impersonation_url": SERVICE_ACCOUNT_IMPERSONATION_URL, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "quota_project_id": QUOTA_PROJECT_ID, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + } + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(info)) + credentials = identity_pool.Credentials.from_file(str(config_file)) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=QUOTA_PROJECT_ID, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_file_required_options_only(self, mock_init, tmpdir): + info = { + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + } + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(info)) + credentials = identity_pool.Credentials.from_file(str(config_file)) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=None, + ) + + def test_constructor_invalid_options(self): + credential_source = {"unsupported": "value"} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Missing credential_source") + + def test_constructor_invalid_options_url_and_file(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "file": SUBJECT_TOKEN_TEXT_FILE, + } + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Ambiguous credential_source") + + def test_constructor_invalid_options_environment_id(self): + credential_source = {"url": self.CREDENTIAL_URL, "environment_id": "aws1"} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match( + r"Invalid Identity Pool credential_source field 'environment_id'" + ) + + def test_constructor_invalid_credential_source(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source="non-dict") + + assert excinfo.match(r"Missing credential_source") + + def test_constructor_invalid_credential_source_format_type(self): + credential_source = {"format": {"type": "xml"}} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match(r"Invalid credential_source format 'xml'") + + def test_constructor_missing_subject_token_field_name(self): + credential_source = {"format": {"type": "json"}} + + with pytest.raises(ValueError) as excinfo: + self.make_credentials(credential_source=credential_source) + + assert excinfo.match( + r"Missing subject_token_field_name for JSON credential_source format" + ) + + def test_retrieve_subject_token_missing_subject_token(self, tmpdir): + # Provide empty text file. + empty_file = tmpdir.join("empty.txt") + empty_file.write("") + credential_source = {"file": str(empty_file)} + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match(r"Missing subject_token in the credential_source file") + + def test_retrieve_subject_token_text_file(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT + ) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + + def test_retrieve_subject_token_json_file(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON + ) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + + def test_retrieve_subject_token_json_file_invalid_field_name(self): + credential_source = { + "file": SUBJECT_TOKEN_JSON_FILE, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + SUBJECT_TOKEN_JSON_FILE, "not_found" + ) + ) + + def test_retrieve_subject_token_invalid_json(self, tmpdir): + # Provide JSON file. This should result in JSON parsing error. + invalid_json_file = tmpdir.join("invalid.json") + invalid_json_file.write("{") + credential_source = { + "file": str(invalid_json_file), + "format": {"type": "json", "subject_token_field_name": "access_token"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + str(invalid_json_file), "access_token" + ) + ) + + def test_retrieve_subject_token_file_not_found(self): + credential_source = {"file": "./not_found.txt"} + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(None) + + assert excinfo.match(r"File './not_found.txt' was not found") + + def test_refresh_text_file_success_without_impersonation_ignore_default_scopes( + self + ): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=["ignored"], + ) + + def test_refresh_text_file_success_without_impersonation_use_default_scopes(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=None, + # Default scopes should be used since user specified scopes are none. + default_scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=None, + default_scopes=SCOPES, + ) + + def test_refresh_text_file_success_with_impersonation_ignore_default_scopes(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + # Default scopes should be ignored. + default_scopes=["ignored"], + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=["ignored"], + ) + + def test_refresh_text_file_success_with_impersonation_use_default_scopes(self): + # Initialize credentials with service account impersonation, basic auth + # and default scopes (no user scopes). + credentials = self.make_credentials( + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=None, + # Default scopes should be used since user specified scopes are none. + default_scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=None, + default_scopes=SCOPES, + ) + + def test_refresh_json_file_success_without_impersonation(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + ) + + def test_refresh_json_file_success_with_impersonation(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + ) + + def test_refresh_with_retrieve_subject_token_error(self): + credential_source = { + "file": SUBJECT_TOKEN_JSON_FILE, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(None) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + SUBJECT_TOKEN_JSON_FILE, "not_found" + ) + ) + + def test_retrieve_subject_token_from_url(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL + ) + request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) + + def test_retrieve_subject_token_from_url_with_headers(self): + credentials = self.make_credentials( + credential_source={"url": self.CREDENTIAL_URL, "headers": {"foo": "bar"}} + ) + request = self.make_mock_request(token_data=TEXT_FILE_SUBJECT_TOKEN) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == TEXT_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, {"foo": "bar"} + ) + + def test_retrieve_subject_token_from_url_json(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON_URL + ) + request = self.make_mock_request(token_data=JSON_FILE_CONTENT) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs(request.call_args_list[0].kwargs, None) + + def test_retrieve_subject_token_from_url_json_with_headers(self): + credentials = self.make_credentials( + credential_source={ + "url": self.CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "access_token"}, + "headers": {"foo": "bar"}, + } + ) + request = self.make_mock_request(token_data=JSON_FILE_CONTENT) + subject_token = credentials.retrieve_subject_token(request) + + assert subject_token == JSON_FILE_SUBJECT_TOKEN + self.assert_credential_request_kwargs( + request.call_args_list[0].kwargs, {"foo": "bar"} + ) + + def test_retrieve_subject_token_from_url_not_found(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL + ) + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token( + self.make_mock_request(token_status=404, token_data=JSON_FILE_CONTENT) + ) + + assert excinfo.match("Unable to retrieve Identity Pool subject token") + + def test_retrieve_subject_token_from_url_json_invalid_field(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token( + self.make_mock_request(token_data=JSON_FILE_CONTENT) + ) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + self.CREDENTIAL_URL, "not_found" + ) + ) + + def test_retrieve_subject_token_from_url_json_invalid_format(self): + credentials = self.make_credentials( + credential_source=self.CREDENTIAL_SOURCE_JSON_URL + ) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.retrieve_subject_token(self.make_mock_request(token_data="{")) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + self.CREDENTIAL_URL, "access_token" + ) + ) + + def test_refresh_text_file_success_without_impersonation_url(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + credential_data=TEXT_FILE_SUBJECT_TOKEN, + ) + + def test_refresh_text_file_success_with_impersonation_url(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + credential_data=TEXT_FILE_SUBJECT_TOKEN, + ) + + def test_refresh_json_file_success_without_impersonation_url(self): + credentials = self.make_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + credential_data=JSON_FILE_CONTENT, + ) + + def test_refresh_json_file_success_with_impersonation_url(self): + # Initialize credentials with service account impersonation and basic auth. + credentials = self.make_credentials( + # Test with JSON format type. + credential_source=self.CREDENTIAL_SOURCE_JSON_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=SCOPES, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=AUDIENCE, + subject_token=JSON_FILE_SUBJECT_TOKEN, + subject_token_type=SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + default_scopes=None, + credential_data=JSON_FILE_CONTENT, + ) + + def test_refresh_with_retrieve_subject_token_error_url(self): + credential_source = { + "url": self.CREDENTIAL_URL, + "format": {"type": "json", "subject_token_field_name": "not_found"}, + } + credentials = self.make_credentials(credential_source=credential_source) + + with pytest.raises(exceptions.RefreshError) as excinfo: + credentials.refresh(self.make_mock_request(token_data=JSON_FILE_CONTENT)) + + assert excinfo.match( + "Unable to parse subject_token from JSON file '{}' using key '{}'".format( + self.CREDENTIAL_URL, "not_found" + ) + ) diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py index 305f93926..430c770d3 100644 --- a/tests/test_impersonated_credentials.py +++ b/tests/test_impersonated_credentials.py @@ -104,12 +104,17 @@ class TestImpersonatedCredentials(object): SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI ) USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE") + IAM_ENDPOINT_OVERRIDE = ( + "https://us-east1-iamcredentials.googleapis.com/v1/projects/-" + + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL) + ) def make_credentials( self, source_credentials=SOURCE_CREDENTIALS, lifetime=LIFETIME, target_principal=TARGET_PRINCIPAL, + iam_endpoint_override=None, ): return Credentials( @@ -118,6 +123,7 @@ def make_credentials( target_scopes=self.TARGET_SCOPES, delegates=self.DELEGATES, lifetime=lifetime, + iam_endpoint_override=iam_endpoint_override, ) def test_make_from_user_credentials(self): @@ -172,6 +178,34 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials): assert credentials.valid assert not credentials.expired + @pytest.mark.parametrize("use_data_bytes", [True, False]) + def test_refresh_success_iam_endpoint_override( + self, use_data_bytes, mock_donor_credentials + ): + credentials = self.make_credentials( + lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE + ) + token = "token" + + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500) + ).isoformat("T") + "Z" + response_body = {"accessToken": token, "expireTime": expire_time} + + request = self.make_request( + data=json.dumps(response_body), + status=http_client.OK, + use_data_bytes=use_data_bytes, + ) + + credentials.refresh(request) + + assert credentials.valid + assert not credentials.expired + # Confirm override endpoint used. + request_kwargs = request.call_args.kwargs + assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE + @pytest.mark.parametrize("time_skew", [100, -100]) def test_refresh_source_credentials(self, time_skew): credentials = self.make_credentials(lifetime=None) @@ -317,6 +351,36 @@ def test_with_quota_project(self): quota_project_creds = credentials.with_quota_project("project-foo") assert quota_project_creds._quota_project_id == "project-foo" + @pytest.mark.parametrize("use_data_bytes", [True, False]) + def test_with_quota_project_iam_endpoint_override( + self, use_data_bytes, mock_donor_credentials + ): + credentials = self.make_credentials( + lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE + ) + token = "token" + # iam_endpoint_override should be copied to created credentials. + quota_project_creds = credentials.with_quota_project("project-foo") + + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500) + ).isoformat("T") + "Z" + response_body = {"accessToken": token, "expireTime": expire_time} + + request = self.make_request( + data=json.dumps(response_body), + status=http_client.OK, + use_data_bytes=use_data_bytes, + ) + + quota_project_creds.refresh(request) + + assert quota_project_creds.valid + assert not quota_project_creds.expired + # Confirm override endpoint used. + request_kwargs = request.call_args.kwargs + assert request_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE + def test_id_token_success( self, mock_donor_credentials, mock_authorizedsession_idtoken ):