From c813d55a78053ecbec1a9640e6c9814da87319eb Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 30 Sep 2021 15:20:36 -0400 Subject: [PATCH] fix: timing of stale token refreshes on ComputeEngine (#749) * fix: timing of stale token refreshes on ComputeEngine ComputeEngine metadata server has its own token caching mechanism that will return a cached token until the last 5 minutes of its expiration. This has a negative interaction with stale token refreshes because stale token refresh kicks in T-6mins until T-5mins. This will cause every stale refresh to return the same stale token. This PR updates the timing for ComputeEngineCredentials to start a stale refresh at T-4mins and consider the token expired at T-3 mins. The implementation is a bit noisy because it includes a change OAuth2Credentials to make the thresholds configureable and now that we targeting java8, I migrated to using java8 time data types * fmt * fix test * fix test again * remove debug code * comments --- .../auth/oauth2/ComputeEngineCredentials.java | 11 ++ .../google/auth/oauth2/GoogleCredentials.java | 11 ++ .../google/auth/oauth2/OAuth2Credentials.java | 53 ++++-- .../oauth2/MockRequestMetadataCallback.java | 4 +- .../auth/oauth2/OAuth2CredentialsTest.java | 153 +++++++++++++----- 5 files changed, 181 insertions(+), 51 deletions(-) diff --git a/oauth2_http/java/com/google/auth/oauth2/ComputeEngineCredentials.java b/oauth2_http/java/com/google/auth/oauth2/ComputeEngineCredentials.java index 01988d1f4..92ab0cb34 100644 --- a/oauth2_http/java/com/google/auth/oauth2/ComputeEngineCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/ComputeEngineCredentials.java @@ -50,6 +50,7 @@ import java.io.ObjectInputStream; import java.net.SocketTimeoutException; import java.net.UnknownHostException; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -71,6 +72,14 @@ public class ComputeEngineCredentials extends GoogleCredentials implements ServiceAccountSigner, IdTokenProvider { + // Decrease timing margins on GCE. + // This is needed because GCE VMs maintain their own OAuth cache that expires T-5mins, attempting + // to refresh a token before then, will yield the same stale token. To enable pre-emptive + // refreshes, the margins must be shortened. This shouldn't cause problems since the clock skew + // on the VM and metadata proxy should be non-existent. + static final Duration COMPUTE_EXPIRATION_MARGIN = Duration.ofMinutes(3); + static final Duration COMPUTE_REFRESH_MARGIN = Duration.ofMinutes(4); + private static final Logger LOGGER = Logger.getLogger(ComputeEngineCredentials.class.getName()); static final String DEFAULT_METADATA_SERVER_URL = "http://metadata.google.internal"; @@ -116,6 +125,8 @@ private ComputeEngineCredentials( HttpTransportFactory transportFactory, Collection scopes, Collection defaultScopes) { + super(/* accessToken= */ null, COMPUTE_REFRESH_MARGIN, COMPUTE_EXPIRATION_MARGIN); + this.transportFactory = firstNonNull( transportFactory, diff --git a/oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java b/oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java index af7cc8ecd..235b4eeb1 100644 --- a/oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java @@ -40,6 +40,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -213,6 +214,16 @@ public GoogleCredentials(AccessToken accessToken) { super(accessToken); } + /** + * Constructor with explicit access token and refresh times + * + * @param accessToken initial or temporary access token + */ + protected GoogleCredentials( + AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) { + super(accessToken, refreshMargin, expirationMargin); + } + public static Builder newBuilder() { return new Builder(); } diff --git a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java index c1b11bdc2..3012990f3 100644 --- a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java @@ -31,14 +31,13 @@ package com.google.auth.oauth2; -import static java.util.concurrent.TimeUnit.MINUTES; - import com.google.api.client.util.Clock; import com.google.auth.Credentials; import com.google.auth.RequestMetadataCallback; import com.google.auth.http.AuthHttpConstants; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; @@ -51,6 +50,7 @@ import java.io.ObjectInputStream; import java.io.Serializable; import java.net.URI; +import java.time.Duration; import java.util.ArrayList; import java.util.Date; import java.util.List; @@ -67,10 +67,13 @@ public class OAuth2Credentials extends Credentials { private static final long serialVersionUID = 4556936364828217687L; - static final long MINIMUM_TOKEN_MILLISECONDS = MINUTES.toMillis(5); - static final long REFRESH_MARGIN_MILLISECONDS = MINIMUM_TOKEN_MILLISECONDS + MINUTES.toMillis(1); + static final Duration DEFAULT_EXPIRATION_MARGIN = Duration.ofMinutes(5); + static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(6); private static final ImmutableMap> EMPTY_EXTRA_HEADERS = ImmutableMap.of(); + private final Duration expirationMargin; + private final Duration refreshMargin; + // byte[] is serializable, so the lock variable can be final @VisibleForTesting final Object lock = new byte[0]; private volatile OAuthValue value = null; @@ -102,9 +105,20 @@ protected OAuth2Credentials() { * @param accessToken initial or temporary access token */ protected OAuth2Credentials(AccessToken accessToken) { + this(accessToken, DEFAULT_REFRESH_MARGIN, DEFAULT_EXPIRATION_MARGIN); + } + + protected OAuth2Credentials( + AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) { if (accessToken != null) { this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS); } + + this.refreshMargin = Preconditions.checkNotNull(refreshMargin, "refreshMargin"); + Preconditions.checkArgument(!refreshMargin.isNegative(), "refreshMargin can't be negative"); + this.expirationMargin = Preconditions.checkNotNull(expirationMargin, "expirationMargin"); + Preconditions.checkArgument( + !expirationMargin.isNegative(), "expirationMargin can't be negative"); } @Override @@ -324,13 +338,12 @@ private CacheState getState() { return CacheState.FRESH; } - long remainingMillis = expirationTime.getTime() - clock.currentTimeMillis(); - - if (remainingMillis <= MINIMUM_TOKEN_MILLISECONDS) { + Duration remaining = Duration.ofMillis(expirationTime.getTime() - clock.currentTimeMillis()); + if (remaining.compareTo(expirationMargin) <= 0) { return CacheState.EXPIRED; } - if (remainingMillis <= REFRESH_MARGIN_MILLISECONDS) { + if (remaining.compareTo(refreshMargin) <= 0) { return CacheState.STALE; } @@ -572,11 +585,15 @@ void executeIfNew(Executor executor) { public static class Builder { private AccessToken accessToken; + private Duration refreshMargin = DEFAULT_REFRESH_MARGIN; + private Duration expirationMargin = DEFAULT_EXPIRATION_MARGIN; protected Builder() {} protected Builder(OAuth2Credentials credentials) { this.accessToken = credentials.getAccessToken(); + this.refreshMargin = credentials.refreshMargin; + this.expirationMargin = credentials.expirationMargin; } public Builder setAccessToken(AccessToken token) { @@ -584,12 +601,30 @@ public Builder setAccessToken(AccessToken token) { return this; } + public Builder setRefreshMargin(Duration refreshMargin) { + this.refreshMargin = refreshMargin; + return this; + } + + public Duration getRefreshMargin() { + return refreshMargin; + } + + public Builder setExpirationMargin(Duration expirationMargin) { + this.expirationMargin = expirationMargin; + return this; + } + + public Duration getExpirationMargin() { + return expirationMargin; + } + public AccessToken getAccessToken() { return accessToken; } public OAuth2Credentials build() { - return new OAuth2Credentials(accessToken); + return new OAuth2Credentials(accessToken, refreshMargin, expirationMargin); } } } diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java b/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java index 691879eae..0d98a9a8e 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java @@ -39,8 +39,8 @@ /** Mock RequestMetadataCallback */ public final class MockRequestMetadataCallback implements RequestMetadataCallback { - Map> metadata; - Throwable exception; + volatile Map> metadata; + volatile Throwable exception; CountDownLatch latch = new CountDownLatch(1); /** Called when metadata is successfully produced. */ diff --git a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java index f4eb16812..dcfc77f04 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java @@ -49,14 +49,14 @@ import com.google.auth.oauth2.OAuth2Credentials.OAuthValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFutureTask; import com.google.common.util.concurrent.SettableFuture; import java.io.IOException; import java.net.URI; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; -import java.util.Calendar; import java.util.Date; import java.util.List; import java.util.Map; @@ -66,6 +66,8 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.FutureTask; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; @@ -103,6 +105,87 @@ public void constructor_storesAccessToken() { assertEquals(credentials.getAccessToken().getTokenValue(), ACCESS_TOKEN); } + @Test + public void constructor_overrideMargin() throws Throwable { + Duration staleMargin = Duration.ofMinutes(3); + Duration expirationMargin = Duration.ofMinutes(2); + + Instant actualExpiration = Instant.now(); + Instant clientStale = actualExpiration.minus(staleMargin); + Instant clientExpired = actualExpiration.minus(expirationMargin); + + AccessToken initialToken = new AccessToken(ACCESS_TOKEN, Date.from(actualExpiration)); + AtomicInteger refreshCount = new AtomicInteger(); + AtomicReference currentToken = new AtomicReference<>(initialToken); + + OAuth2Credentials credentials = + new OAuth2Credentials( + currentToken.get(), + /* refreshMargin= */ Duration.ofMinutes(3), + /* expirationMargin= */ Duration.ofMinutes(2)) { + @Override + public AccessToken refreshAccessToken() throws IOException { + refreshCount.incrementAndGet(); + // Inject delay to model network latency + // This is needed to make to deflake the stale tests: + // if the refresh is super quick, then a stale refresh will return the new token + try { + Thread.sleep(100); + } catch (InterruptedException e) { + throw new IOException(e); + } + + return currentToken.get(); + } + }; + + TestClock clock = new TestClock(); + credentials.clock = clock; + + // Rewind time to when the token is fresh + clock.setCurrentTime(clientStale.toEpochMilli() - 1); + MockRequestMetadataCallback callback = new MockRequestMetadataCallback(); + credentials.getRequestMetadata(CALL_URI, realExecutor, callback); + synchronized (credentials.lock) { + assertNull(credentials.refreshTask); + } + assertEquals(0, refreshCount.get()); + Map> lastMetadata = credentials.getRequestMetadata(CALL_URI); + + // Fast forward to when the token just turned STALE + clock.setCurrentTime(clientStale.toEpochMilli()); + currentToken.set(new AccessToken(ACCESS_TOKEN + "-1", Date.from(actualExpiration))); + callback.reset(); + credentials.getRequestMetadata(CALL_URI, realExecutor, callback); + assertEquals(lastMetadata, callback.awaitResult()); + waitForRefreshTaskCompletion(credentials); + assertEquals(1, refreshCount.get()); + lastMetadata = credentials.getRequestMetadata(CALL_URI); + refreshCount.set(0); + + // Fast forward to when the token turned STALE just before expiration + clock.setCurrentTime(clientExpired.toEpochMilli() - 1); + currentToken.set(new AccessToken(ACCESS_TOKEN + "-2", Date.from(actualExpiration))); + callback.reset(); + credentials.getRequestMetadata(CALL_URI, realExecutor, callback); + assertEquals(lastMetadata, callback.awaitResult()); + waitForRefreshTaskCompletion(credentials); + assertEquals(1, refreshCount.get()); + lastMetadata = credentials.getRequestMetadata(); + refreshCount.set(0); + + // Fast forward to expired + clock.setCurrentTime(clientExpired.toEpochMilli()); + AccessToken newToken = new AccessToken(ACCESS_TOKEN + "-3", Date.from(actualExpiration)); + currentToken.set(newToken); + callback.reset(); + credentials.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.awaitResult(), newToken.getTokenValue()); + assertEquals(1, refreshCount.get()); + waitForRefreshTaskCompletion(credentials); + lastMetadata = credentials.getRequestMetadata(); + } + @Test public void getAuthenticationType_returnsOAuth2() { OAuth2Credentials credentials = @@ -406,21 +489,16 @@ public void getRequestMetadata_temporaryToken_hasToken() throws IOException { @Test public void getRequestMetadata_staleTemporaryToken() throws IOException, InterruptedException { - Calendar calendar = Calendar.getInstance(); - Date actualExpiration = calendar.getTime(); - - calendar.setTime(actualExpiration); - calendar.add( - Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); - Date clientStale = calendar.getTime(); + Instant actualExpiration = Instant.now(); + Instant clientStale = actualExpiration.minus(OAuth2Credentials.DEFAULT_REFRESH_MARGIN); TestClock testClock = new TestClock(); - testClock.setCurrentTime(clientStale.getTime()); + testClock.setCurrentTime(clientStale.toEpochMilli()); // Initialize credentials which are initially stale and set to refresh final SettableFuture refreshedTokenFuture = SettableFuture.create(); OAuth2Credentials creds = - new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, actualExpiration)) { + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, Date.from(actualExpiration))) { @Override public AccessToken refreshAccessToken() { @@ -481,25 +559,16 @@ public AccessToken refreshAccessToken() { @Test public void getRequestMetadata_staleTemporaryToken_expirationWaits() throws Throwable { - Calendar calendar = Calendar.getInstance(); - Date actualExpiration = calendar.getTime(); - - calendar.setTime(actualExpiration); - calendar.add( - Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); - Date clientStale = calendar.getTime(); - - calendar.setTime(actualExpiration); - calendar.add( - Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.MINIMUM_TOKEN_MILLISECONDS)); - Date clientExpired = calendar.getTime(); + Instant actualExpiration = Instant.now(); + Instant clientStale = actualExpiration.minus(OAuth2Credentials.DEFAULT_REFRESH_MARGIN); + Instant clientExpired = actualExpiration.minus(OAuth2Credentials.DEFAULT_EXPIRATION_MARGIN); TestClock testClock = new TestClock(); // Initialize credentials which are initially stale and set to refresh final SettableFuture refreshedTokenFuture = SettableFuture.create(); OAuth2Credentials creds = - new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, actualExpiration)) { + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, Date.from(actualExpiration))) { @Override public AccessToken refreshAccessToken() { @@ -516,7 +585,7 @@ public AccessToken refreshAccessToken() { } // Calls should return immediately with stale token, but a refresh is scheduled - testClock.setCurrentTime(clientStale.getTime()); + testClock.setCurrentTime(clientStale.toEpochMilli()); MockRequestMetadataCallback callback = new MockRequestMetadataCallback(); creds.getRequestMetadata(CALL_URI, realExecutor, callback); TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN); @@ -524,7 +593,7 @@ public AccessToken refreshAccessToken() { ListenableFutureTask refreshTask = creds.refreshTask; // Fast forward to expiration, which will hang cause the callback to hang - testClock.setCurrentTime(clientExpired.getTime()); + testClock.setCurrentTime(clientExpired.toEpochMilli()); // Make sure that the callback is hung (while giving it a chance to run) for (int i = 0; i < 10; i++) { Thread.sleep(10); @@ -559,26 +628,17 @@ public AccessToken refreshAccessToken() { @Test public void getRequestMetadata_singleFlightErrorSharing() { - Calendar calendar = Calendar.getInstance(); - Date actualExpiration = calendar.getTime(); - - calendar.setTime(actualExpiration); - calendar.add( - Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); - Date clientStale = calendar.getTime(); - - calendar.setTime(actualExpiration); - calendar.add( - Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.MINIMUM_TOKEN_MILLISECONDS)); - Date clientExpired = calendar.getTime(); + Instant actualExpiration = Instant.now(); + Instant clientStale = actualExpiration.minus(OAuth2Credentials.DEFAULT_REFRESH_MARGIN); + Instant clientExpired = actualExpiration.minus(OAuth2Credentials.DEFAULT_EXPIRATION_MARGIN); TestClock testClock = new TestClock(); - testClock.setCurrentTime(clientStale.getTime()); + testClock.setCurrentTime(clientStale.toEpochMilli()); // Initialize credentials which are initially expired final SettableFuture refreshErrorFuture = SettableFuture.create(); final OAuth2Credentials creds = - new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, clientExpired)) { + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, Date.from(clientExpired))) { @Override public AccessToken refreshAccessToken() { RuntimeException injectedError; @@ -812,6 +872,19 @@ public void serialize() throws IOException, ClassNotFoundException { assertSame(deserializedCredentials.clock, Clock.SYSTEM); } + private void waitForRefreshTaskCompletion(OAuth2Credentials credentials) + throws TimeoutException, InterruptedException { + for (int i = 0; i < 100; i++) { + synchronized (credentials.lock) { + if (credentials.refreshTask == null) { + return; + } + } + Thread.sleep(100); + } + throw new TimeoutException("timed out waiting for refresh task to finish"); + } + private static class TestChangeListener implements OAuth2Credentials.CredentialsChangedListener { public AccessToken accessToken = null;