Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: timing of stale token refreshes on ComputeEngine #749

Merged
merged 8 commits into from Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -50,6 +50,7 @@
import java.io.ObjectInputStream;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -71,6 +72,14 @@
public class ComputeEngineCredentials extends GoogleCredentials
implements ServiceAccountSigner, IdTokenProvider {

igorbernstein2 marked this conversation as resolved.
Show resolved Hide resolved
// Decrease timing margins on GCE.
// This is needed because GCE VMs maintain their own OAuth cache that expires T-5mins, attempting
// to refresh a token before then, will yield the same stale token. To enable pre-emptive
// refreshes, the margins must be shortened. This shouldn't cause problems since the clock skew
// on the VM and metadata proxy should be non-existent.
static final Duration COMPUTE_EXPIRATION_MARGIN = Duration.ofMinutes(3);
static final Duration COMPUTE_REFRESH_MARGIN = Duration.ofMinutes(4);

private static final Logger LOGGER = Logger.getLogger(ComputeEngineCredentials.class.getName());

static final String DEFAULT_METADATA_SERVER_URL = "http://metadata.google.internal";
Expand Down Expand Up @@ -116,6 +125,8 @@ private ComputeEngineCredentials(
HttpTransportFactory transportFactory,
Collection<String> scopes,
Collection<String> defaultScopes) {
super(/* accessToken= */ null, COMPUTE_REFRESH_MARGIN, COMPUTE_EXPIRATION_MARGIN);

this.transportFactory =
firstNonNull(
transportFactory,
Expand Down
11 changes: 11 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java
Expand Up @@ -40,6 +40,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
Expand Down Expand Up @@ -213,6 +214,16 @@ public GoogleCredentials(AccessToken accessToken) {
super(accessToken);
}

/**
* Constructor with explicit access token and refresh times
*
* @param accessToken initial or temporary access token
*/
protected GoogleCredentials(
AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
super(accessToken, refreshMargin, expirationMargin);
}

public static Builder newBuilder() {
return new Builder();
}
Expand Down
53 changes: 44 additions & 9 deletions oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java
Expand Up @@ -31,14 +31,13 @@

package com.google.auth.oauth2;

import static java.util.concurrent.TimeUnit.MINUTES;

import com.google.api.client.util.Clock;
import com.google.auth.Credentials;
import com.google.auth.RequestMetadataCallback;
import com.google.auth.http.AuthHttpConstants;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
Expand All @@ -51,6 +50,7 @@
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
Expand All @@ -67,10 +67,13 @@
public class OAuth2Credentials extends Credentials {

private static final long serialVersionUID = 4556936364828217687L;
static final long MINIMUM_TOKEN_MILLISECONDS = MINUTES.toMillis(5);
static final long REFRESH_MARGIN_MILLISECONDS = MINIMUM_TOKEN_MILLISECONDS + MINUTES.toMillis(1);
static final Duration DEFAULT_EXPIRATION_MARGIN = Duration.ofMinutes(5);
static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(6);
private static final ImmutableMap<String, List<String>> EMPTY_EXTRA_HEADERS = ImmutableMap.of();

private final Duration expirationMargin;
private final Duration refreshMargin;

// byte[] is serializable, so the lock variable can be final
@VisibleForTesting final Object lock = new byte[0];
private volatile OAuthValue value = null;
Expand Down Expand Up @@ -102,9 +105,20 @@ protected OAuth2Credentials() {
* @param accessToken initial or temporary access token
*/
protected OAuth2Credentials(AccessToken accessToken) {
this(accessToken, DEFAULT_REFRESH_MARGIN, DEFAULT_EXPIRATION_MARGIN);
}

protected OAuth2Credentials(
AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
if (accessToken != null) {
this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS);
}

this.refreshMargin = Preconditions.checkNotNull(refreshMargin, "refreshMargin");
Preconditions.checkArgument(!refreshMargin.isNegative(), "refreshMargin can't be negative");
this.expirationMargin = Preconditions.checkNotNull(expirationMargin, "expirationMargin");
Preconditions.checkArgument(
!expirationMargin.isNegative(), "expirationMargin can't be negative");
}

@Override
Expand Down Expand Up @@ -324,13 +338,12 @@ private CacheState getState() {
return CacheState.FRESH;
}

long remainingMillis = expirationTime.getTime() - clock.currentTimeMillis();

if (remainingMillis <= MINIMUM_TOKEN_MILLISECONDS) {
Duration remaining = Duration.ofMillis(expirationTime.getTime() - clock.currentTimeMillis());
if (remaining.compareTo(expirationMargin) <= 0) {
return CacheState.EXPIRED;
}

if (remainingMillis <= REFRESH_MARGIN_MILLISECONDS) {
if (remaining.compareTo(refreshMargin) <= 0) {
return CacheState.STALE;
}

Expand Down Expand Up @@ -572,24 +585,46 @@ void executeIfNew(Executor executor) {
public static class Builder {

private AccessToken accessToken;
private Duration refreshMargin = DEFAULT_REFRESH_MARGIN;
private Duration expirationMargin = DEFAULT_EXPIRATION_MARGIN;

protected Builder() {}

protected Builder(OAuth2Credentials credentials) {
this.accessToken = credentials.getAccessToken();
this.refreshMargin = credentials.refreshMargin;
this.expirationMargin = credentials.expirationMargin;
}

public Builder setAccessToken(AccessToken token) {
this.accessToken = token;
return this;
}

public Builder setRefreshMargin(Duration refreshMargin) {
this.refreshMargin = refreshMargin;
return this;
}

public Duration getRefreshMargin() {
return refreshMargin;
}

public Builder setExpirationMargin(Duration expirationMargin) {
this.expirationMargin = expirationMargin;
return this;
}

public Duration getExpirationMargin() {
return expirationMargin;
}

public AccessToken getAccessToken() {
return accessToken;
}

public OAuth2Credentials build() {
return new OAuth2Credentials(accessToken);
return new OAuth2Credentials(accessToken, refreshMargin, expirationMargin);
}
}
}
Expand Up @@ -39,8 +39,8 @@

/** Mock RequestMetadataCallback */
public final class MockRequestMetadataCallback implements RequestMetadataCallback {
Map<String, List<String>> metadata;
Throwable exception;
volatile Map<String, List<String>> metadata;
volatile Throwable exception;
CountDownLatch latch = new CountDownLatch(1);

/** Called when metadata is successfully produced. */
Expand Down