diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 24b93b423..f588981a0 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -73,6 +73,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Instantiates an external account credentials object. @@ -90,6 +91,11 @@ def __init__( authorization grant. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. + workforce_pool_user_project (Optona[str]): The optional workforce pool user + project number when the credential corresponds to a workforce pool and not + a workload identity pool. The underlying principal must still have + serviceusage.services.use IAM permission to use the project for + billing/quota. Raises: google.auth.exceptions.RefreshError: If the generateAccessToken endpoint returned an error. @@ -105,6 +111,7 @@ def __init__( self._quota_project_id = quota_project_id self._scopes = scopes self._default_scopes = default_scopes + self._workforce_pool_user_project = workforce_pool_user_project if self._client_id: self._client_auth = utils.ClientAuthentication( @@ -120,6 +127,13 @@ def __init__( self._impersonated_credentials = None self._project_id = None + if not self.is_workforce_pool and self._workforce_pool_user_project: + # Workload identity pools do not support workforce pool user projects. + raise ValueError( + "workforce_pool_user_project should not be set for non-workforce pool " + "credentials" + ) + @property def info(self): """Generates the dictionary representation of the current credentials. @@ -140,6 +154,7 @@ def info(self): "quota_project_id": self._quota_project_id, "client_id": self._client_id, "client_secret": self._client_secret, + "workforce_pool_user_project": self._workforce_pool_user_project, } return {key: value for key, value in config_info.items() if value is not None} @@ -178,12 +193,23 @@ def is_user(self): # service account. if self._service_account_impersonation_url: return False + return self.is_workforce_pool + + @property + def is_workforce_pool(self): + """Returns whether the credentials represent a workforce pool (True) or + workload (False) based on the credentials' audience. + + This will also return True for impersonated workforce pool credentials. + + Returns: + bool: True if the credentials represent a workforce pool. False if they + represent a workload. + """ # Workforce pools representing users have the following audience format: # //iam.googleapis.com/locations/$location/workforcePools/$poolId/providers/$providerId p = re.compile(r"//iam\.googleapis\.com/locations/[^/]+/workforcePools/") - if p.match(self._audience): - return True - return False + return p.match(self._audience or "") is not None @property def requires_scopes(self): @@ -210,7 +236,7 @@ def project_number(self): @_helpers.copy_docstring(credentials.Scoped) def with_scopes(self, scopes, default_scopes=None): - return self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -221,7 +247,11 @@ def with_scopes(self, scopes, default_scopes=None): quota_project_id=self._quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) @abc.abstractmethod def retrieve_subject_token(self, request): @@ -238,7 +268,9 @@ def retrieve_subject_token(self, request): raise NotImplementedError("retrieve_subject_token must be implemented") def get_project_id(self, request): - """Retrieves the project ID corresponding to the workload identity pool. + """Retrieves the project ID corresponding to the workload identity or workforce pool. + For workforce pool credentials, it returns the project ID corresponding to + the workforce_pool_user_project. When not determinable, None is returned. @@ -255,16 +287,17 @@ def get_project_id(self, request): HTTP requests. Returns: Optional[str]: The project ID corresponding to the workload identity pool - if determinable. + or workforce pool if determinable. """ if self._project_id: # If already retrieved, return the cached project ID value. return self._project_id scopes = self._scopes if self._scopes is not None else self._default_scopes # Scopes are required in order to retrieve a valid access token. - if self.project_number and scopes: + project_number = self.project_number or self._workforce_pool_user_project + if project_number and scopes: headers = {} - url = _CLOUD_RESOURCE_MANAGER + self.project_number + url = _CLOUD_RESOURCE_MANAGER + project_number self.before_request(request, "GET", url, headers) response = request(url=url, method="GET", headers=headers) @@ -291,6 +324,11 @@ def refresh(self, request): self.expiry = self._impersonated_credentials.expiry else: now = _helpers.utcnow() + additional_options = None + # Do not pass workforce_pool_user_project when client authentication + # is used. The client ID is sufficient for determining the user project. + if self._workforce_pool_user_project and not self._client_id: + additional_options = {"userProject": self._workforce_pool_user_project} response_data = self._sts_client.exchange_token( request=request, grant_type=_STS_GRANT_TYPE, @@ -299,6 +337,7 @@ def refresh(self, request): audience=self._audience, scopes=scopes, requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + additional_options=additional_options, ) self.token = response_data.get("access_token") lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) @@ -307,7 +346,7 @@ def refresh(self, request): @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): # Return copy of instance with the provided quota project ID. - return self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -318,7 +357,11 @@ def with_quota_project(self, quota_project_id): quota_project_id=quota_project_id, scopes=self._scopes, default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. @@ -336,7 +379,7 @@ def _initialize_impersonated_credentials(self): endpoint returned an error. """ # Return copy of instance with no service account impersonation. - source_credentials = self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -347,7 +390,11 @@ def _initialize_impersonated_credentials(self): quota_project_id=self._quota_project_id, scopes=self._scopes, default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + source_credentials = self.__class__(**d) # Determine target_principal. target_principal = self.service_account_email diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index c331e0921..901fd62fb 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -58,6 +58,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Instantiates an external account credentials object from a file/URL. @@ -95,6 +96,11 @@ def __init__( authorization grant. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. + workforce_pool_user_project (Optona[str]): The optional workforce pool user + project number when the credential corresponds to a workforce pool and not + a workload identity pool. The underlying principal must still have + serviceusage.services.use IAM permission to use the project for + billing/quota. Raises: google.auth.exceptions.RefreshError: If an error is encountered during @@ -117,6 +123,7 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, ) if not isinstance(credential_source, Mapping): self._credential_source_file = None @@ -255,6 +262,7 @@ def from_info(cls, info, **kwargs): client_secret=info.get("client_secret"), credential_source=info.get("credential_source"), quota_project_id=info.get("quota_project_id"), + workforce_pool_user_project=info.get("workforce_pool_user_project"), **kwargs ) diff --git a/tests/test_external_account.py b/tests/test_external_account.py index df6174f17..97f1564ef 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -37,6 +37,33 @@ "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", "//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", ] +# Workload identity pool audiences or invalid workforce pool audiences. +TEST_NON_USER_AUDIENCES = [ + # Legacy K8s audience format. + "identitynamespace:1f12345:my_provider", + ( + "//iam.googleapis.com/projects/123456/locations/" + "global/workloadIdentityPools/pool-id/providers/" + "provider-id" + ), + ( + "//iam.googleapis.com/projects/123456/locations/" + "eu/workloadIdentityPools/pool-id/providers/" + "provider-id" + ), + # Pool ID with workforcePools string. + ( + "//iam.googleapis.com/projects/123456/locations/" + "global/workloadIdentityPools/workforcePools/providers/" + "provider-id" + ), + # Unrealistic / incorrect workforce pool audiences. + "//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", + "//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", + "//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", + "//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", + "//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", +] class CredentialsImpl(external_account.Credentials): @@ -52,6 +79,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): super(CredentialsImpl, self).__init__( audience=audience, @@ -64,6 +92,7 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, ) self._counter = 0 @@ -83,7 +112,12 @@ class TestCredentials(object): "/locations/global/workloadIdentityPools/{}" "/providers/{}" ).format(PROJECT_NUMBER, POOL_ID, PROVIDER_ID) + WORKFORCE_AUDIENCE = ( + "//iam.googleapis.com/locations/global/workforcePools/{}/providers/{}" + ).format(POOL_ID, PROVIDER_ID) + WORKFORCE_POOL_USER_PROJECT = "WORKFORCE_POOL_USER_PROJECT_NUMBER" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" + WORKFORCE_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:id_token" CREDENTIAL_SOURCE = {"file": "/var/run/secrets/goog.id/token"} SUCCESS_RESPONSE = { "access_token": "ACCESS_TOKEN", @@ -146,6 +180,31 @@ def make_credentials( default_scopes=default_scopes, ) + @classmethod + def make_workforce_pool_credentials( + cls, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + service_account_impersonation_url=None, + workforce_pool_user_project=None, + ): + return CredentialsImpl( + audience=cls.WORKFORCE_AUDIENCE, + subject_token_type=cls.WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=cls.TOKEN_URL, + service_account_impersonation_url=service_account_impersonation_url, + credential_source=cls.CREDENTIAL_SOURCE, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, + ) + @classmethod def make_mock_request( cls, @@ -230,6 +289,21 @@ def test_default_state(self): assert credentials.requires_scopes assert not credentials.quota_project_id + def test_nonworkforce_with_workforce_pool_user_project(self): + with pytest.raises(ValueError) as excinfo: + CredentialsImpl( + audience=self.AUDIENCE, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + assert excinfo.match( + "workforce_pool_user_project should not be set for non-workforce " + "pool credentials" + ) + def test_with_scopes(self): credentials = self.make_credentials() @@ -241,6 +315,23 @@ def test_with_scopes(self): assert scoped_credentials.has_scopes(["email"]) assert not scoped_credentials.requires_scopes + def test_with_scopes_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + assert not credentials.scopes + assert credentials.requires_scopes + + scoped_credentials = credentials.with_scopes(["email"]) + + assert scoped_credentials.has_scopes(["email"]) + assert not scoped_credentials.requires_scopes + assert ( + scoped_credentials.info.get("workforce_pool_user_project") + == self.WORKFORCE_POOL_USER_PROJECT + ) + def test_with_scopes_using_user_and_default_scopes(self): credentials = self.make_credentials() @@ -296,6 +387,7 @@ def test_with_scopes_full_options_propagated(self): quota_project_id=self.QUOTA_PROJECT_ID, scopes=["email"], default_scopes=["default2"], + workforce_pool_user_project=None, ) def test_with_quota_project(self): @@ -308,6 +400,22 @@ def test_with_quota_project(self): assert quota_project_creds.quota_project_id == "project-foo" + def test_with_quota_project_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + assert not credentials.scopes + assert not credentials.quota_project_id + + quota_project_creds = credentials.with_quota_project("project-foo") + + assert quota_project_creds.quota_project_id == "project-foo" + assert ( + quota_project_creds.info.get("workforce_pool_user_project") + == self.WORKFORCE_POOL_USER_PROJECT + ) + def test_with_quota_project_full_options_propagated(self): credentials = self.make_credentials( client_id=CLIENT_ID, @@ -336,6 +444,7 @@ def test_with_quota_project_full_options_propagated(self): quota_project_id="project-foo", scopes=self.SCOPES, default_scopes=["default1"], + workforce_pool_user_project=None, ) def test_with_invalid_impersonation_target_principal(self): @@ -359,6 +468,20 @@ def test_info(self): "credential_source": self.CREDENTIAL_SOURCE.copy(), } + def test_info_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + assert credentials.info == { + "type": "external_account", + "audience": self.WORKFORCE_AUDIENCE, + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": self.TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE.copy(), + "workforce_pool_user_project": self.WORKFORCE_POOL_USER_PROJECT, + } + def test_info_with_full_options(self): credentials = self.make_credentials( client_id=CLIENT_ID, @@ -391,36 +514,7 @@ def test_service_account_email_with_impersonation(self): assert credentials.service_account_email == SERVICE_ACCOUNT_EMAIL - @pytest.mark.parametrize( - "audience", - # Workload identity pool audiences or invalid workforce pool audiences. - [ - # Legacy K8s audience format. - "identitynamespace:1f12345:my_provider", - ( - "//iam.googleapis.com/projects/123456/locations/" - "global/workloadIdentityPools/pool-id/providers/" - "provider-id" - ), - ( - "//iam.googleapis.com/projects/123456/locations/" - "eu/workloadIdentityPools/pool-id/providers/" - "provider-id" - ), - # Pool ID with workforcePools string. - ( - "//iam.googleapis.com/projects/123456/locations/" - "global/workloadIdentityPools/workforcePools/providers/" - "provider-id" - ), - # Unrealistic / incorrect workforce pool audiences. - "//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", - "//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", - "//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", - "//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", - "//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", - ], - ) + @pytest.mark.parametrize("audience", TEST_NON_USER_AUDIENCES) def test_is_user_with_non_users(self, audience): credentials = CredentialsImpl( audience=audience, @@ -458,6 +552,43 @@ def test_is_user_with_users_and_impersonation(self, audience): # not a user. assert credentials.is_user is False + @pytest.mark.parametrize("audience", TEST_NON_USER_AUDIENCES) + def test_is_workforce_pool_with_non_users(self, audience): + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.is_workforce_pool is False + + @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES) + def test_is_workforce_pool_with_users(self, audience): + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.is_workforce_pool is True + + @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES) + def test_is_workforce_pool_with_users_and_impersonation(self, audience): + # Initialize the credentials with workforce audience and service account + # impersonation. + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + ) + + # Even though impersonation is used, is_workforce_pool should still return True. + assert credentials.is_workforce_pool is True + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test_refresh_without_client_auth_success(self, unused_utcnow): response = self.SUCCESS_RESPONSE.copy() @@ -485,6 +616,110 @@ def test_refresh_without_client_auth_success(self, unused_utcnow): assert not credentials.expired assert credentials.token == response["access_token"] + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_workforce_without_client_auth_success(self, unused_utcnow): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "options": urllib.parse.quote( + json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT}) + ), + } + request = self.make_mock_request(status=http.client.OK, data=response) + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_workforce_with_client_auth_success(self, unused_utcnow): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request(status=http.client.OK, data=response) + # Client Auth will have higher priority over workforce_pool_user_project. + credentials = self.make_workforce_pool_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_workforce_with_client_auth_and_no_workforce_project_success( + self, unused_utcnow + ): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request(status=http.client.OK, data=response) + # Client Auth will be sufficient for user project determination. + credentials = self.make_workforce_pool_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + workforce_pool_user_project=None, + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + def test_refresh_impersonation_without_client_auth_success(self): # Simulate service account access token expires in 2800 seconds. expire_time = ( @@ -549,6 +784,74 @@ def test_refresh_impersonation_without_client_auth_success(self): assert not credentials.expired assert credentials.token == impersonation_response["accessToken"] + def test_refresh_workforce_impersonation_without_client_auth_success(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + "options": urllib.parse.quote( + json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT}) + ), + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http.client.OK, + data=token_response, + impersonation_status=http.client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_workforce_pool_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1][1], + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + def test_refresh_without_client_auth_success_explicit_user_scopes_ignore_default_scopes( self, ): @@ -822,6 +1125,22 @@ def test_apply_without_quota_project_id(self): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) } + def test_apply_workforce_without_quota_project_id(self): + headers = {} + request = self.make_mock_request( + status=http.client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) + } + def test_apply_impersonation_without_quota_project_id(self): expire_time = ( _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) @@ -926,6 +1245,31 @@ def test_before_request(self): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), } + def test_before_request_workforce(self): + headers = {"other": "header-value"} + request = self.make_mock_request( + status=http.client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + # First call should call refresh, setting the token. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + } + + # Second call shouldn't call refresh. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + } + def test_before_request_impersonation(self): expire_time = ( _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) @@ -1091,6 +1435,17 @@ def test_project_number_determinable(self): assert credentials.project_number == self.PROJECT_NUMBER + def test_project_number_workforce(self): + credentials = CredentialsImpl( + audience=self.WORKFORCE_AUDIENCE, + subject_token_type=self.WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + assert credentials.project_number is None + def test_project_id_without_scopes(self): # Initialize credentials with no scopes. credentials = CredentialsImpl( @@ -1190,6 +1545,68 @@ def test_get_project_id_cloud_resource_manager_success(self): # No additional requests. assert len(request.call_args_list) == 3 + def test_workforce_pool_get_project_id_cloud_resource_manager_success(self): + # STS token exchange request/response. + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "scope": "scope1 scope2", + "options": urllib.parse.quote( + json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT}) + ), + } + # Initialize mock request to handle token exchange and cloud resource + # manager request. + request = self.make_mock_request( + status=http.client.OK, + data=self.SUCCESS_RESPONSE.copy(), + cloud_resource_manager_status=http.client.OK, + cloud_resource_manager_data=self.CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE, + ) + credentials = self.make_workforce_pool_credentials( + scopes=self.SCOPES, + quota_project_id=self.QUOTA_PROJECT_ID, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + # Expected project ID from cloud resource manager response should be returned. + project_id = credentials.get_project_id(request) + + assert project_id == self.PROJECT_ID + # 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # In the process of getting project ID, an access token should be + # retrieved. + assert credentials.valid + assert not credentials.expired + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + # Verify cloud resource manager request parameters. + self.assert_resource_manager_request_kwargs( + request.call_args_list[1][1], + self.WORKFORCE_POOL_USER_PROJECT, + { + "x-goog-user-project": self.QUOTA_PROJECT_ID, + "authorization": "Bearer {}".format( + self.SUCCESS_RESPONSE["access_token"] + ), + }, + ) + + # Calling get_project_id again should return the cached project_id. + project_id = credentials.get_project_id(request) + + assert project_id == self.PROJECT_ID + # No additional requests. + assert len(request.call_args_list) == 2 + def test_get_project_id_cloud_resource_manager_error(self): # Simulate resource doesn't have sufficient permissions to access # cloud resource manager. diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index efe11b082..e90e2880d 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -53,6 +53,11 @@ TOKEN_URL = "https://sts.googleapis.com/v1/token" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID" +WORKFORCE_AUDIENCE = ( + "//iam.googleapis.com/locations/global/workforcePools/POOL_ID/providers/PROVIDER_ID" +) +WORKFORCE_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:id_token" +WORKFORCE_POOL_USER_PROJECT = "WORKFORCE_POOL_USER_PROJECT_NUMBER" class TestCredentials(object): @@ -158,6 +163,7 @@ def assert_underlying_credentials_refresh( credential_data=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Utility to assert that a credentials are initialized with the expected attributes by calling refresh functionality and confirming response matches @@ -183,6 +189,10 @@ def assert_underlying_credentials_refresh( "subject_token": subject_token, "subject_token_type": subject_token_type, } + if workforce_pool_user_project: + token_request_data["options"] = urllib.parse.quote( + json.dumps({"userProject": workforce_pool_user_project}) + ) if service_account_impersonation_url: # Service account impersonation request/response. @@ -250,6 +260,8 @@ def assert_underlying_credentials_refresh( @classmethod def make_credentials( cls, + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, client_id=None, client_secret=None, quota_project_id=None, @@ -257,10 +269,11 @@ def make_credentials( default_scopes=None, service_account_impersonation_url=None, credential_source=None, + workforce_pool_user_project=None, ): return identity_pool.Credentials( - audience=AUDIENCE, - subject_token_type=SUBJECT_TOKEN_TYPE, + audience=audience, + subject_token_type=subject_token_type, token_url=TOKEN_URL, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, @@ -269,6 +282,7 @@ def make_credentials( quota_project_id=quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -297,6 +311,7 @@ def test_from_info_full_options(self, mock_init): client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=QUOTA_PROJECT_ID, + workforce_pool_user_project=None, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -321,6 +336,33 @@ def test_from_info_required_options_only(self, mock_init): client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=None, + workforce_pool_user_project=None, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_workforce_pool(self, mock_init): + credentials = identity_pool.Credentials.from_info( + { + "audience": WORKFORCE_AUDIENCE, + "subject_token_type": WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + "workforce_pool_user_project": WORKFORCE_POOL_USER_PROJECT, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=None, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -350,6 +392,7 @@ def test_from_file_full_options(self, mock_init, tmpdir): client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=QUOTA_PROJECT_ID, + workforce_pool_user_project=None, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -375,6 +418,46 @@ def test_from_file_required_options_only(self, mock_init, tmpdir): client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=None, + workforce_pool_user_project=None, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_file_workforce_pool(self, mock_init, tmpdir): + info = { + "audience": WORKFORCE_AUDIENCE, + "subject_token_type": WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + "workforce_pool_user_project": WORKFORCE_POOL_USER_PROJECT, + } + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(info)) + credentials = identity_pool.Credentials.from_file(str(config_file)) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=None, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + def test_constructor_nonworkforce_with_workforce_pool_user_project(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + audience=AUDIENCE, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + assert excinfo.match( + "workforce_pool_user_project should not be set for non-workforce " + "pool credentials" ) def test_constructor_invalid_options(self): @@ -430,6 +513,23 @@ def test_constructor_missing_subject_token_field_name(self): r"Missing subject_token_field_name for JSON credential_source format" ) + def test_info_with_workforce_pool_user_project(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL.copy(), + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + assert credentials.info == { + "type": "external_account", + "audience": WORKFORCE_AUDIENCE, + "subject_token_type": WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT_URL, + "workforce_pool_user_project": WORKFORCE_POOL_USER_PROJECT, + } + def test_info_with_file_credential_source(self): credentials = self.make_credentials( credential_source=self.CREDENTIAL_SOURCE_TEXT_URL.copy() @@ -557,6 +657,115 @@ def test_refresh_text_file_success_without_impersonation_ignore_default_scopes( default_scopes=["ignored"], ) + def test_refresh_workforce_success_with_client_auth_without_impersonation(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This will be ignored in favor of client auth. + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=None, + ) + + def test_refresh_workforce_success_with_client_auth_and_no_workforce_project(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This is not needed when client Auth is used. + workforce_pool_user_project=None, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=None, + ) + + def test_refresh_workforce_success_without_client_auth_without_impersonation(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=None, + client_secret=None, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This will not be ignored as client auth is not used. + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + def test_refresh_workforce_success_without_client_auth_with_impersonation(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=None, + client_secret=None, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This will not be ignored as client auth is not used. + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + def test_refresh_text_file_success_without_impersonation_use_default_scopes(self): credentials = self.make_credentials( client_id=CLIENT_ID,