Skip to content

Commit

Permalink
fix: fallback to source creds expiration in downscoped tokens (#805)
Browse files Browse the repository at this point in the history
For downscoping CAB flow, the STS endpoint may not return the expiration
field for certain source credentials. The generated downscoped token
should always have the same expiration time as the source credentials.
When no `expires_in` field is returned in the response, we can just get
the expiration time from the source credentials.

Co-authored-by: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com>
  • Loading branch information
bojeil-google and arithmetic1728 committed Jul 20, 2021
1 parent df9f2f9 commit dfad661
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 4 deletions.
12 changes: 10 additions & 2 deletions google/auth/downscoped.py
Expand Up @@ -479,8 +479,16 @@ def refresh(self, request):
additional_options=self._credential_access_boundary.to_json(),
)
self.token = response_data.get("access_token")
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
self.expiry = now + lifetime
# For downscoping CAB flow, the STS endpoint may not return the expiration
# field for some flows. The generated downscoped token should always have
# the same expiration time as the source credentials. When no expires_in
# field is returned in the response, we can just get the expiration time
# from the source credentials.
if response_data.get("expires_in"):
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
self.expiry = now + lifetime
else:
self.expiry = self._source_credentials.expiry

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
Expand Down
46 changes: 44 additions & 2 deletions tests/test_downscoped.py
Expand Up @@ -80,10 +80,11 @@


class SourceCredentials(credentials.Credentials):
def __init__(self, raise_error=False):
def __init__(self, raise_error=False, expires_in=3600):
super(SourceCredentials, self).__init__()
self._counter = 0
self._raise_error = raise_error
self._expires_in = expires_in

def refresh(self, request):
if self._raise_error:
Expand All @@ -93,7 +94,7 @@ def refresh(self, request):
now = _helpers.utcnow()
self._counter += 1
self.token = "ACCESS_TOKEN_{}".format(self._counter)
self.expiry = now + datetime.timedelta(seconds=3600)
self.expiry = now + datetime.timedelta(seconds=self._expires_in)


def make_availability_condition(expression, title=None, description=None):
Expand Down Expand Up @@ -539,6 +540,47 @@ def test_refresh(self, unused_utcnow):
# Confirm source credentials called with the same request instance.
wrapped_souce_cred_refresh.assert_called_with(request)

@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
def test_refresh_without_response_expires_in(self, unused_utcnow):
response = SUCCESS_RESPONSE.copy()
# Simulate the response is missing the expires_in field.
# The downscoped token expiration should match the source credentials
# expiration.
del response["expires_in"]
expected_expires_in = 1800
# Simulate the source credentials generates a token with 1800 second
# expiration time. The generated downscoped token should have the same
# expiration time.
source_credentials = SourceCredentials(expires_in=expected_expires_in)
expected_expiry = datetime.datetime.min + datetime.timedelta(
seconds=expected_expires_in
)
headers = {"Content-Type": "application/x-www-form-urlencoded"}
request_data = {
"grant_type": GRANT_TYPE,
"subject_token": "ACCESS_TOKEN_1",
"subject_token_type": SUBJECT_TOKEN_TYPE,
"requested_token_type": REQUESTED_TOKEN_TYPE,
"options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)),
}
request = self.make_mock_request(status=http_client.OK, data=response)
credentials = self.make_credentials(source_credentials=source_credentials)

# Spy on calls to source credentials refresh to confirm the expected request
# instance is used.
with mock.patch.object(
source_credentials, "refresh", wraps=source_credentials.refresh
) as wrapped_souce_cred_refresh:
credentials.refresh(request)

self.assert_request_kwargs(request.call_args[1], headers, request_data)
assert credentials.valid
assert credentials.expiry == expected_expiry
assert not credentials.expired
assert credentials.token == response["access_token"]
# Confirm source credentials called with the same request instance.
wrapped_souce_cred_refresh.assert_called_with(request)

def test_refresh_token_exchange_error(self):
request = self.make_mock_request(
status=http_client.BAD_REQUEST, data=ERROR_RESPONSE
Expand Down

0 comments on commit dfad661

Please sign in to comment.