diff --git a/google/auth/downscoped.py b/google/auth/downscoped.py index 800f2894c..96a4e6547 100644 --- a/google/auth/downscoped.py +++ b/google/auth/downscoped.py @@ -479,8 +479,16 @@ def refresh(self, request): additional_options=self._credential_access_boundary.to_json(), ) self.token = response_data.get("access_token") - lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) - self.expiry = now + lifetime + # For downscoping CAB flow, the STS endpoint may not return the expiration + # field for some flows. The generated downscoped token should always have + # the same expiration time as the source credentials. When no expires_in + # field is returned in the response, we can just get the expiration time + # from the source credentials. + if response_data.get("expires_in"): + lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) + self.expiry = now + lifetime + else: + self.expiry = self._source_credentials.expiry @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): diff --git a/tests/test_downscoped.py b/tests/test_downscoped.py index ac60e5b00..795ec2942 100644 --- a/tests/test_downscoped.py +++ b/tests/test_downscoped.py @@ -80,10 +80,11 @@ class SourceCredentials(credentials.Credentials): - def __init__(self, raise_error=False): + def __init__(self, raise_error=False, expires_in=3600): super(SourceCredentials, self).__init__() self._counter = 0 self._raise_error = raise_error + self._expires_in = expires_in def refresh(self, request): if self._raise_error: @@ -93,7 +94,7 @@ def refresh(self, request): now = _helpers.utcnow() self._counter += 1 self.token = "ACCESS_TOKEN_{}".format(self._counter) - self.expiry = now + datetime.timedelta(seconds=3600) + self.expiry = now + datetime.timedelta(seconds=self._expires_in) def make_availability_condition(expression, title=None, description=None): @@ -539,6 +540,47 @@ def test_refresh(self, unused_utcnow): # Confirm source credentials called with the same request instance. wrapped_souce_cred_refresh.assert_called_with(request) + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_without_response_expires_in(self, unused_utcnow): + response = SUCCESS_RESPONSE.copy() + # Simulate the response is missing the expires_in field. + # The downscoped token expiration should match the source credentials + # expiration. + del response["expires_in"] + expected_expires_in = 1800 + # Simulate the source credentials generates a token with 1800 second + # expiration time. The generated downscoped token should have the same + # expiration time. + source_credentials = SourceCredentials(expires_in=expected_expires_in) + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=expected_expires_in + ) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": GRANT_TYPE, + "subject_token": "ACCESS_TOKEN_1", + "subject_token_type": SUBJECT_TOKEN_TYPE, + "requested_token_type": REQUESTED_TOKEN_TYPE, + "options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)), + } + request = self.make_mock_request(status=http_client.OK, data=response) + credentials = self.make_credentials(source_credentials=source_credentials) + + # Spy on calls to source credentials refresh to confirm the expected request + # instance is used. + with mock.patch.object( + source_credentials, "refresh", wraps=source_credentials.refresh + ) as wrapped_souce_cred_refresh: + credentials.refresh(request) + + self.assert_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + # Confirm source credentials called with the same request instance. + wrapped_souce_cred_refresh.assert_called_with(request) + def test_refresh_token_exchange_error(self): request = self.make_mock_request( status=http_client.BAD_REQUEST, data=ERROR_RESPONSE