Skip to content

Commit

Permalink
fix: timing of stale token refreshes on ComputeEngine (#749)
Browse files Browse the repository at this point in the history
* fix: timing of stale token refreshes on ComputeEngine

ComputeEngine metadata server has its own token caching mechanism that will return a cached token until the last 5 minutes of its expiration. This has a negative interaction with stale token refreshes because stale token refresh kicks in T-6mins until T-5mins. This will cause every stale refresh to return the same stale token.

This PR updates the timing for ComputeEngineCredentials to start a stale refresh at T-4mins and consider the token expired at T-3 mins. The implementation is a bit noisy because it includes a change OAuth2Credentials to make the thresholds configureable and now that we targeting java8, I migrated to using java8 time data types

* fmt

* fix test

* fix test again

* remove debug code

* comments
  • Loading branch information
igorbernstein2 committed Sep 30, 2021
1 parent e1cbce1 commit c813d55
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 51 deletions.
Expand Up @@ -50,6 +50,7 @@
import java.io.ObjectInputStream;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -71,6 +72,14 @@
public class ComputeEngineCredentials extends GoogleCredentials
implements ServiceAccountSigner, IdTokenProvider {

// Decrease timing margins on GCE.
// This is needed because GCE VMs maintain their own OAuth cache that expires T-5mins, attempting
// to refresh a token before then, will yield the same stale token. To enable pre-emptive
// refreshes, the margins must be shortened. This shouldn't cause problems since the clock skew
// on the VM and metadata proxy should be non-existent.
static final Duration COMPUTE_EXPIRATION_MARGIN = Duration.ofMinutes(3);
static final Duration COMPUTE_REFRESH_MARGIN = Duration.ofMinutes(4);

private static final Logger LOGGER = Logger.getLogger(ComputeEngineCredentials.class.getName());

static final String DEFAULT_METADATA_SERVER_URL = "http://metadata.google.internal";
Expand Down Expand Up @@ -116,6 +125,8 @@ private ComputeEngineCredentials(
HttpTransportFactory transportFactory,
Collection<String> scopes,
Collection<String> defaultScopes) {
super(/* accessToken= */ null, COMPUTE_REFRESH_MARGIN, COMPUTE_EXPIRATION_MARGIN);

this.transportFactory =
firstNonNull(
transportFactory,
Expand Down
11 changes: 11 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java
Expand Up @@ -40,6 +40,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -213,6 +214,16 @@ public GoogleCredentials(AccessToken accessToken) {
super(accessToken);
}

/**
* Constructor with explicit access token and refresh times
*
* @param accessToken initial or temporary access token
*/
protected GoogleCredentials(
AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
super(accessToken, refreshMargin, expirationMargin);
}

public static Builder newBuilder() {
return new Builder();
}
Expand Down
53 changes: 44 additions & 9 deletions oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java
Expand Up @@ -31,14 +31,13 @@

package com.google.auth.oauth2;

import static java.util.concurrent.TimeUnit.MINUTES;

import com.google.api.client.util.Clock;
import com.google.auth.Credentials;
import com.google.auth.RequestMetadataCallback;
import com.google.auth.http.AuthHttpConstants;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
Expand All @@ -51,6 +50,7 @@
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
Expand All @@ -67,10 +67,13 @@
public class OAuth2Credentials extends Credentials {

private static final long serialVersionUID = 4556936364828217687L;
static final long MINIMUM_TOKEN_MILLISECONDS = MINUTES.toMillis(5);
static final long REFRESH_MARGIN_MILLISECONDS = MINIMUM_TOKEN_MILLISECONDS + MINUTES.toMillis(1);
static final Duration DEFAULT_EXPIRATION_MARGIN = Duration.ofMinutes(5);
static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(6);
private static final ImmutableMap<String, List<String>> EMPTY_EXTRA_HEADERS = ImmutableMap.of();

private final Duration expirationMargin;
private final Duration refreshMargin;

// byte[] is serializable, so the lock variable can be final
@VisibleForTesting final Object lock = new byte[0];
private volatile OAuthValue value = null;
Expand Down Expand Up @@ -102,9 +105,20 @@ protected OAuth2Credentials() {
* @param accessToken initial or temporary access token
*/
protected OAuth2Credentials(AccessToken accessToken) {
this(accessToken, DEFAULT_REFRESH_MARGIN, DEFAULT_EXPIRATION_MARGIN);
}

protected OAuth2Credentials(
AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
if (accessToken != null) {
this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS);
}

this.refreshMargin = Preconditions.checkNotNull(refreshMargin, "refreshMargin");
Preconditions.checkArgument(!refreshMargin.isNegative(), "refreshMargin can't be negative");
this.expirationMargin = Preconditions.checkNotNull(expirationMargin, "expirationMargin");
Preconditions.checkArgument(
!expirationMargin.isNegative(), "expirationMargin can't be negative");
}

@Override
Expand Down Expand Up @@ -324,13 +338,12 @@ private CacheState getState() {
return CacheState.FRESH;
}

long remainingMillis = expirationTime.getTime() - clock.currentTimeMillis();

if (remainingMillis <= MINIMUM_TOKEN_MILLISECONDS) {
Duration remaining = Duration.ofMillis(expirationTime.getTime() - clock.currentTimeMillis());
if (remaining.compareTo(expirationMargin) <= 0) {
return CacheState.EXPIRED;
}

if (remainingMillis <= REFRESH_MARGIN_MILLISECONDS) {
if (remaining.compareTo(refreshMargin) <= 0) {
return CacheState.STALE;
}

Expand Down Expand Up @@ -572,24 +585,46 @@ void executeIfNew(Executor executor) {
public static class Builder {

private AccessToken accessToken;
private Duration refreshMargin = DEFAULT_REFRESH_MARGIN;
private Duration expirationMargin = DEFAULT_EXPIRATION_MARGIN;

protected Builder() {}

protected Builder(OAuth2Credentials credentials) {
this.accessToken = credentials.getAccessToken();
this.refreshMargin = credentials.refreshMargin;
this.expirationMargin = credentials.expirationMargin;
}

public Builder setAccessToken(AccessToken token) {
this.accessToken = token;
return this;
}

public Builder setRefreshMargin(Duration refreshMargin) {
this.refreshMargin = refreshMargin;
return this;
}

public Duration getRefreshMargin() {
return refreshMargin;
}

public Builder setExpirationMargin(Duration expirationMargin) {
this.expirationMargin = expirationMargin;
return this;
}

public Duration getExpirationMargin() {
return expirationMargin;
}

public AccessToken getAccessToken() {
return accessToken;
}

public OAuth2Credentials build() {
return new OAuth2Credentials(accessToken);
return new OAuth2Credentials(accessToken, refreshMargin, expirationMargin);
}
}
}
Expand Up @@ -39,8 +39,8 @@

/** Mock RequestMetadataCallback */
public final class MockRequestMetadataCallback implements RequestMetadataCallback {
Map<String, List<String>> metadata;
Throwable exception;
volatile Map<String, List<String>> metadata;
volatile Throwable exception;
CountDownLatch latch = new CountDownLatch(1);

/** Called when metadata is successfully produced. */
Expand Down

0 comments on commit c813d55

Please sign in to comment.