Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fallback to source creds expiration in downscoped tokens #805

Merged
merged 2 commits into from Jul 20, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 10 additions & 2 deletions google/auth/downscoped.py
Expand Up @@ -479,8 +479,16 @@ def refresh(self, request):
additional_options=self._credential_access_boundary.to_json(),
)
self.token = response_data.get("access_token")
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
self.expiry = now + lifetime
# For downscoping CAB flow, the STS endpoint may not return the expiration
# field for some flows. The generated downscoped token should always have
# the same expiration time as the source credentials. When no expires_in
# field is returned in the response, we can just get the expiration time
# from the source credentials.
if response_data.get("expires_in"):
lifetime = datetime.timedelta(seconds=response_data.get("expires_in"))
self.expiry = now + lifetime
else:
self.expiry = self._source_credentials.expiry

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
Expand Down
46 changes: 44 additions & 2 deletions tests/test_downscoped.py
Expand Up @@ -80,10 +80,11 @@


class SourceCredentials(credentials.Credentials):
def __init__(self, raise_error=False):
def __init__(self, raise_error=False, expires_in=3600):
super(SourceCredentials, self).__init__()
self._counter = 0
self._raise_error = raise_error
self._expires_in = expires_in

def refresh(self, request):
if self._raise_error:
Expand All @@ -93,7 +94,7 @@ def refresh(self, request):
now = _helpers.utcnow()
self._counter += 1
self.token = "ACCESS_TOKEN_{}".format(self._counter)
self.expiry = now + datetime.timedelta(seconds=3600)
self.expiry = now + datetime.timedelta(seconds=self._expires_in)


def make_availability_condition(expression, title=None, description=None):
Expand Down Expand Up @@ -539,6 +540,47 @@ def test_refresh(self, unused_utcnow):
# Confirm source credentials called with the same request instance.
wrapped_souce_cred_refresh.assert_called_with(request)

@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
def test_refresh_without_response_expires_in(self, unused_utcnow):
response = SUCCESS_RESPONSE.copy()
# Simulate the response is missing the expires_in field.
# The downscoped token expiration should match the source credentials
# expiration.
del response["expires_in"]
expected_expires_in = 1800
# Simulate the source credentials generates a token with 1800 second
# expiration time. The generated downscoped token should have the same
# expiration time.
source_credentials = SourceCredentials(expires_in=expected_expires_in)
expected_expiry = datetime.datetime.min + datetime.timedelta(
seconds=expected_expires_in
)
headers = {"Content-Type": "application/x-www-form-urlencoded"}
request_data = {
"grant_type": GRANT_TYPE,
"subject_token": "ACCESS_TOKEN_1",
"subject_token_type": SUBJECT_TOKEN_TYPE,
"requested_token_type": REQUESTED_TOKEN_TYPE,
"options": urllib.parse.quote(json.dumps(CREDENTIAL_ACCESS_BOUNDARY_JSON)),
}
request = self.make_mock_request(status=http_client.OK, data=response)
credentials = self.make_credentials(source_credentials=source_credentials)

# Spy on calls to source credentials refresh to confirm the expected request
# instance is used.
with mock.patch.object(
source_credentials, "refresh", wraps=source_credentials.refresh
) as wrapped_souce_cred_refresh:
credentials.refresh(request)

self.assert_request_kwargs(request.call_args[1], headers, request_data)
assert credentials.valid
assert credentials.expiry == expected_expiry
assert not credentials.expired
assert credentials.token == response["access_token"]
# Confirm source credentials called with the same request instance.
wrapped_souce_cred_refresh.assert_called_with(request)

def test_refresh_token_exchange_error(self):
request = self.make_mock_request(
status=http_client.BAD_REQUEST, data=ERROR_RESPONSE
Expand Down