From e3f4c7eac0417705553ef8259599ec29fc8ad9b4 Mon Sep 17 00:00:00 2001 From: Igor Bernstein Date: Thu, 20 May 2021 10:24:45 -0400 Subject: [PATCH] feat: enable pre-emptive async oauth token refreshes (#646) * feat: add pre-emptive async oauth token refreshes This is currently a rough sketch and should not be merged. I just wanted to get some feedback here. The current implementation of oauth refresh offloads the IO to a separate executor, however when a token expires, all calls to getRequestMetadata would hang until the token is refreshed. This PR is a rough sketch to improve the situation by adding a stale state to token. If a token is within a minute of expiration, the first request to notice this, will spawn a refresh on the executor and immediately return with the existing token. This avoids hourly latency spikes for grpc. The implementation uses guava's ListenableFutures to manage the async refresh state. Although the apis are marked BetaApi in guava 19, they are GA in guava 30 * The class introduces 3 distinct states: * Expired - the token is not usable * Stale - the token is useable, but its time to refresh * Fresh - token can be used without any extra effort * All of the funtionality of getRequestMetadata has been extracted to asyncFetch. The new function will check if the token is unfresh and schedule a refresh using the executor * asyncFetch uses ListenableFutures to wrap state: if the token is not expired, an immediate future is returned. If the token is expired the future of the refresh task is returned * A helper refreshAsync & finishRefreshAsync are also introduced. They schedule the actual refresh and propagate the side effects * To handle blocking invocation: the async functionality is re-used but a DirectExecutor is used. All ExecutionErrors are unwrapped. In most cases the stack trace will be preserved because of the DirectExecutor. However if the async & sync methods are interleaved, it's possible that a sync call will await an async refresh task. In this case the callers stacktrace will not be present. * minor doc * update broken test * prep for merging with master: The initial async refresh feature was implemented on top of 0.8, so now I'm backporting features to minimize the merge conflicts * in blocking mode, when a token is stale, only block the first caller and allow subsequent callers to use the stale token * use private monitor to minimize change noise * minor tweaks and test * format * fix ExternalAccountCredentials * fix boundary checks and add a few more tests * add another test for making sure that blocking stacktraces include the caller * address feedback * optimize for the common case * codestyle * use Date to calculate cache state to fix tests that mock access token * remove accidental double call Co-authored-by: Les Vogel --- .../oauth2/ExternalAccountCredentials.java | 22 + .../google/auth/oauth2/OAuth2Credentials.java | 386 ++++++++++++++---- .../oauth2/MockRequestMetadataCallback.java | 14 + .../auth/oauth2/MockTokenServerTransport.java | 33 +- .../auth/oauth2/OAuth2CredentialsTest.java | 331 ++++++++++++++- 5 files changed, 684 insertions(+), 102 deletions(-) diff --git a/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java b/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java index 9293d5855..5ded1a140 100644 --- a/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java @@ -35,6 +35,7 @@ import com.google.api.client.json.GenericJson; import com.google.api.client.json.JsonObjectParser; +import com.google.auth.RequestMetadataCallback; import com.google.auth.http.HttpTransportFactory; import com.google.auth.oauth2.AwsCredentials.AwsCredentialSource; import com.google.auth.oauth2.IdentityPoolCredentials.IdentityPoolCredentialSource; @@ -48,6 +49,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.Executor; import javax.annotation.Nullable; /** @@ -208,6 +210,26 @@ private ImpersonatedCredentials initializeImpersonatedCredentials() { .build(); } + @Override + public void getRequestMetadata( + URI uri, Executor executor, final RequestMetadataCallback callback) { + super.getRequestMetadata( + uri, + executor, + new RequestMetadataCallback() { + @Override + public void onSuccess(Map> metadata) { + metadata = addQuotaProjectIdToRequestMetadata(quotaProjectId, metadata); + callback.onSuccess(metadata); + } + + @Override + public void onFailure(Throwable exception) { + callback.onFailure(exception); + } + }); + } + @Override public Map> getRequestMetadata(URI uri) throws IOException { Map> requestMetadata = super.getRequestMetadata(uri); diff --git a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java index f22d3c74e..c1b11bdc2 100644 --- a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java @@ -31,37 +31,50 @@ package com.google.auth.oauth2; +import static java.util.concurrent.TimeUnit.MINUTES; + import com.google.api.client.util.Clock; import com.google.auth.Credentials; import com.google.auth.RequestMetadataCallback; import com.google.auth.http.AuthHttpConstants; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListenableFutureTask; +import com.google.common.util.concurrent.MoreExecutors; import java.io.IOException; import java.io.ObjectInputStream; +import java.io.Serializable; import java.net.URI; import java.util.ArrayList; -import java.util.Collections; import java.util.Date; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.ServiceLoader; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.Future; +import javax.annotation.Nullable; /** Base type for Credentials using OAuth2. */ public class OAuth2Credentials extends Credentials { private static final long serialVersionUID = 4556936364828217687L; - private static final long MINIMUM_TOKEN_MILLISECONDS = 60000L * 5L; - private static final Map> EMPTY_EXTRA_HEADERS = Collections.emptyMap(); + static final long MINIMUM_TOKEN_MILLISECONDS = MINUTES.toMillis(5); + static final long REFRESH_MARGIN_MILLISECONDS = MINIMUM_TOKEN_MILLISECONDS + MINUTES.toMillis(1); + private static final ImmutableMap> EMPTY_EXTRA_HEADERS = ImmutableMap.of(); // byte[] is serializable, so the lock variable can be final - private final Object lock = new byte[0]; - private Map> requestMetadata; - private AccessToken temporaryAccess; + @VisibleForTesting final Object lock = new byte[0]; + private volatile OAuthValue value = null; + @VisibleForTesting transient ListenableFutureTask refreshTask; // Change listeners are not serialized private transient List changeListeners; @@ -90,7 +103,7 @@ protected OAuth2Credentials() { */ protected OAuth2Credentials(AccessToken accessToken) { if (accessToken != null) { - useAccessToken(accessToken, EMPTY_EXTRA_HEADERS); + this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS); } } @@ -117,25 +130,21 @@ public boolean hasRequestMetadataOnly() { * @return The cached access token. */ public final AccessToken getAccessToken() { - return temporaryAccess; + OAuthValue localState = value; + if (localState != null) { + return localState.temporaryAccess; + } + return null; } @Override public void getRequestMetadata( final URI uri, Executor executor, final RequestMetadataCallback callback) { - Map> metadata; - synchronized (lock) { - if (shouldRefresh()) { - // The base class implementation will do a blocking get in the executor. - super.getRequestMetadata(uri, executor, callback); - return; - } - if (requestMetadata == null) { - throw new NullPointerException("cached requestMetadata"); - } - metadata = requestMetadata; - } - callback.onSuccess(metadata); + + Futures.addCallback( + asyncFetch(executor), + new FutureCallbackToMetadataCallbackAdapter(callback), + MoreExecutors.directExecutor()); } /** @@ -144,75 +153,188 @@ public void getRequestMetadata( */ @Override public Map> getRequestMetadata(URI uri) throws IOException { - synchronized (lock) { - if (shouldRefresh()) { - refresh(); - } - if (requestMetadata == null) { - throw new NullPointerException("requestMetadata"); - } - return requestMetadata; - } + return unwrapDirectFuture(asyncFetch(MoreExecutors.directExecutor())).requestMetadata; } - /** Refresh the token by discarding the cached token and metadata and requesting the new ones. */ + /** + * Request a new token regardless of the current token state. If the current token is not expired, + * it will still be returned during the refresh. + */ @Override public void refresh() throws IOException { + AsyncRefreshResult refreshResult = getOrCreateRefreshTask(); + refreshResult.executeIfNew(MoreExecutors.directExecutor()); + unwrapDirectFuture(refreshResult.task); + } + + /** + * Refresh these credentials only if they have expired or are expiring imminently. + * + * @throws IOException during token refresh. + */ + public void refreshIfExpired() throws IOException { + // asyncFetch will ensure that the token is refreshed + unwrapDirectFuture(asyncFetch(MoreExecutors.directExecutor())); + } + + /** + * Attempts to get a fresh token. + * + *

If a fresh token is already available, it will be immediately returned. Otherwise a refresh + * will be scheduled using the passed in executor. While a token is being freshed, a stale value + * will be returned. + */ + private ListenableFuture asyncFetch(Executor executor) { + AsyncRefreshResult refreshResult = null; + + // fast and common path: skip the lock if the token is fresh + // The inherent race condition here is a non-issue: even if the value gets replaced after the + // state check, the new token will still be fresh. + if (getState() == CacheState.FRESH) { + return Futures.immediateFuture(value); + } + + // Schedule a refresh as necessary synchronized (lock) { - requestMetadata = null; - temporaryAccess = null; - AccessToken accessToken = refreshAccessToken(); - if (accessToken == null) { - throw new NullPointerException("new access token"); + if (getState() != CacheState.FRESH) { + refreshResult = getOrCreateRefreshTask(); } - useAccessToken(accessToken, getAdditionalHeaders()); - if (changeListeners != null) { - for (CredentialsChangedListener listener : changeListeners) { - listener.onChanged(this); - } + } + // Execute the refresh if necessary. This should be done outside of the lock to avoid blocking + // metadata requests during a stale refresh. + if (refreshResult != null) { + refreshResult.executeIfNew(executor); + } + + synchronized (lock) { + // Immediately resolve the token token if its not expired, or wait for the refresh task to + // complete + if (getState() != CacheState.EXPIRED) { + return Futures.immediateFuture(value); + } else if (refreshResult != null) { + return refreshResult.task; + } else { + // Should never happen + return Futures.immediateFailedFuture( + new IllegalStateException("Credentials expired, but there is no task to refresh")); } } } /** - * Provide additional headers to return as request metadata. + * Atomically creates a single flight refresh token task. * - * @return additional headers + *

Only a single refresh task can be scheduled at a time. If there is an existing task, it will + * be returned for subsequent invocations. However if a new task is created, it is the + * responsibility of the caller to execute it. The task will clear the single flight slow upon + * completion. */ - protected Map> getAdditionalHeaders() { - return EMPTY_EXTRA_HEADERS; + private AsyncRefreshResult getOrCreateRefreshTask() { + synchronized (lock) { + if (refreshTask != null) { + return new AsyncRefreshResult(refreshTask, false); + } + + final ListenableFutureTask task = + ListenableFutureTask.create( + new Callable() { + @Override + public OAuthValue call() throws Exception { + return OAuthValue.create(refreshAccessToken(), getAdditionalHeaders()); + } + }); + + task.addListener( + new Runnable() { + @Override + public void run() { + finishRefreshAsync(task); + } + }, + MoreExecutors.directExecutor()); + + refreshTask = task; + + return new AsyncRefreshResult(refreshTask, true); + } } /** - * Refresh these credentials only if they have expired or are expiring imminently. + * Async callback for committing the result from a token refresh. * - * @throws IOException during token refresh. + *

The result will be stored, listeners are invoked and the single flight slot is cleared. */ - public void refreshIfExpired() throws IOException { + private void finishRefreshAsync(ListenableFuture finishedTask) { synchronized (lock) { - if (shouldRefresh()) { - refresh(); + try { + this.value = finishedTask.get(); + for (CredentialsChangedListener listener : changeListeners) { + listener.onChanged(this); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + // noop + } finally { + if (this.refreshTask == finishedTask) { + this.refreshTask = null; + } } } } - // Must be called under lock - private void useAccessToken(AccessToken token, Map> additionalHeaders) { - this.temporaryAccess = token; - this.requestMetadata = - ImmutableMap.>builder() - .put( - AuthHttpConstants.AUTHORIZATION, - Collections.singletonList(OAuth2Utils.BEARER_PREFIX + token.getTokenValue())) - .putAll(additionalHeaders) - .build(); + /** + * Unwraps the value from the future. + * + *

Under most circumstances, the underlying future will already be resolved by the + * DirectExecutor. In those cases, the error stacktraces will be rooted in the caller's call tree. + * However, in some cases when async and sync usage is mixed, it's possible that a blocking call + * will await an async future. In those cases, the stacktrace will be orphaned and be rooted in a + * thread of whatever executor the async call used. This doesn't affect correctness and is + * extremely unlikely. + */ + private static T unwrapDirectFuture(Future future) throws IOException { + try { + return future.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while asynchronously refreshing the access token", e); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + throw (IOException) cause; + } else if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } else { + throw new IOException("Unexpected error refreshing access token", cause); + } + } } - // Must be called under lock - // requestMetadata will never be null if false is returned. - private boolean shouldRefresh() { - Long expiresIn = getExpiresInMilliseconds(); - return requestMetadata == null || expiresIn != null && expiresIn <= MINIMUM_TOKEN_MILLISECONDS; + /** Computes the effective credential state in relation to the current time. */ + private CacheState getState() { + OAuthValue localValue = value; + + if (localValue == null) { + return CacheState.EXPIRED; + } + Date expirationTime = localValue.temporaryAccess.getExpirationTime(); + + if (expirationTime == null) { + return CacheState.FRESH; + } + + long remainingMillis = expirationTime.getTime() - clock.currentTimeMillis(); + + if (remainingMillis <= MINIMUM_TOKEN_MILLISECONDS) { + return CacheState.EXPIRED; + } + + if (remainingMillis <= REFRESH_MARGIN_MILLISECONDS) { + return CacheState.STALE; + } + + return CacheState.FRESH; } /** @@ -233,6 +355,15 @@ public AccessToken refreshAccessToken() throws IOException { + " that supports refreshing."); } + /** + * Provide additional headers to return as request metadata. + * + * @return additional headers + */ + protected Map> getAdditionalHeaders() { + return EMPTY_EXTRA_HEADERS; + } + /** * Adds a listener that is notified when the Credentials data changes. * @@ -263,21 +394,6 @@ public final void removeChangeListener(CredentialsChangedListener listener) { } } - /** - * Return the remaining time the current access token will be valid, or null if there is no token - * or expiry information. Must be called under lock. - */ - private Long getExpiresInMilliseconds() { - if (temporaryAccess == null) { - return null; - } - Date expirationTime = temporaryAccess.getExpirationTime(); - if (expirationTime == null) { - return null; - } - return (expirationTime.getTime() - clock.currentTimeMillis()); - } - /** * Listener for changes to credentials. * @@ -300,15 +416,29 @@ public interface CredentialsChangedListener { @Override public int hashCode() { - return Objects.hash(requestMetadata, temporaryAccess); + return Objects.hashCode(value); } + @Nullable protected Map> getRequestMetadataInternal() { - return requestMetadata; + OAuthValue localValue = value; + if (localValue != null) { + return localValue.requestMetadata; + } + return null; } @Override public String toString() { + OAuthValue localValue = value; + + Map> requestMetadata = null; + AccessToken temporaryAccess = null; + + if (localValue != null) { + requestMetadata = localValue.requestMetadata; + temporaryAccess = localValue.temporaryAccess; + } return MoreObjects.toStringHelper(this) .add("requestMetadata", requestMetadata) .add("temporaryAccess", temporaryAccess) @@ -321,13 +451,13 @@ public boolean equals(Object obj) { return false; } OAuth2Credentials other = (OAuth2Credentials) obj; - return Objects.equals(this.requestMetadata, other.requestMetadata) - && Objects.equals(this.temporaryAccess, other.temporaryAccess); + return Objects.equals(this.value, other.value); } private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException { input.defaultReadObject(); clock = Clock.SYSTEM; + refreshTask = null; } @SuppressWarnings("unchecked") @@ -351,6 +481,94 @@ public Builder toBuilder() { return new Builder(this); } + /** Stores an immutable snapshot of the accesstoken owned by {@link OAuth2Credentials} */ + static class OAuthValue implements Serializable { + private final AccessToken temporaryAccess; + private final Map> requestMetadata; + + static OAuthValue create(AccessToken token, Map> additionalHeaders) { + return new OAuthValue( + token, + ImmutableMap.>builder() + .put( + AuthHttpConstants.AUTHORIZATION, + ImmutableList.of(OAuth2Utils.BEARER_PREFIX + token.getTokenValue())) + .putAll(additionalHeaders) + .build()); + } + + private OAuthValue(AccessToken temporaryAccess, Map> requestMetadata) { + this.temporaryAccess = temporaryAccess; + this.requestMetadata = requestMetadata; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof OAuthValue)) { + return false; + } + OAuthValue other = (OAuthValue) obj; + return Objects.equals(this.requestMetadata, other.requestMetadata) + && Objects.equals(this.temporaryAccess, other.temporaryAccess); + } + + @Override + public int hashCode() { + return Objects.hash(temporaryAccess, requestMetadata); + } + } + + enum CacheState { + FRESH, + STALE, + EXPIRED; + } + + static class FutureCallbackToMetadataCallbackAdapter implements FutureCallback { + private final RequestMetadataCallback callback; + + public FutureCallbackToMetadataCallbackAdapter(RequestMetadataCallback callback) { + this.callback = callback; + } + + @Override + public void onSuccess(@Nullable OAuthValue value) { + callback.onSuccess(value.requestMetadata); + } + + @Override + public void onFailure(Throwable throwable) { + // refreshAccessToken will be invoked in an executor, so if it fails unwrap the underlying + // error + if (throwable instanceof ExecutionException) { + throwable = throwable.getCause(); + } + callback.onFailure(throwable); + } + } + + /** + * Result from {@link com.google.auth.oauth2.OAuth2Credentials#getOrCreateRefreshTask()}. + * + *

Contains the the refresh task and a flag indicating if the task is newly created. If the + * task is newly created, it is the caller's responsibility to execute it. + */ + static class AsyncRefreshResult { + private final ListenableFutureTask task; + private final boolean isNew; + + AsyncRefreshResult(ListenableFutureTask task, boolean isNew) { + this.task = task; + this.isNew = isNew; + } + + void executeIfNew(Executor executor) { + if (isNew) { + executor.execute(task); + } + } + } + public static class Builder { private AccessToken accessToken; diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java b/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java index db50328c3..691879eae 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java @@ -35,17 +35,20 @@ import com.google.common.base.Preconditions; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; /** Mock RequestMetadataCallback */ public final class MockRequestMetadataCallback implements RequestMetadataCallback { Map> metadata; Throwable exception; + CountDownLatch latch = new CountDownLatch(1); /** Called when metadata is successfully produced. */ @Override public void onSuccess(Map> metadata) { checkNotSet(); this.metadata = metadata; + latch.countDown(); } /** Called when metadata generation failed. */ @@ -53,11 +56,22 @@ public void onSuccess(Map> metadata) { public void onFailure(Throwable exception) { checkNotSet(); this.exception = exception; + latch.countDown(); } public void reset() { this.metadata = null; this.exception = null; + latch = new CountDownLatch(1); + } + + public Map> awaitResult() throws Throwable { + latch.await(); + if (exception != null) { + throw exception; + } else { + return metadata; + } } private void checkNotSet() { diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java b/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java index efa09d7c7..4b431d28a 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java @@ -42,6 +42,7 @@ import com.google.api.client.testing.http.MockLowLevelHttpRequest; import com.google.api.client.testing.http.MockLowLevelHttpResponse; import com.google.auth.TestUtils; +import com.google.common.util.concurrent.Futures; import java.io.IOException; import java.net.URI; import java.util.ArrayDeque; @@ -49,6 +50,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Queue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; /** Mock transport to simulate providing Google OAuth2 access tokens */ public class MockTokenServerTransport extends MockHttpTransport { @@ -62,8 +65,7 @@ public class MockTokenServerTransport extends MockHttpTransport { final Map codes = new HashMap(); URI tokenServerUri = OAuth2Utils.TOKEN_SERVER_URI; private IOException error; - private Queue responseErrorSequence = new ArrayDeque(); - private Queue responseSequence = new ArrayDeque(); + private final Queue> responseSequence = new ArrayDeque<>(); private int expiresInSeconds = 3600; public MockTokenServerTransport() {} @@ -103,16 +105,20 @@ public void setError(IOException error) { public void addResponseErrorSequence(IOException... errors) { for (IOException error : errors) { - responseErrorSequence.add(error); + responseSequence.add(Futures.immediateFailedFuture(error)); } } public void addResponseSequence(LowLevelHttpResponse... responses) { for (LowLevelHttpResponse response : responses) { - responseSequence.add(response); + responseSequence.add(Futures.immediateFuture(response)); } } + public void addResponseSequence(Future response) { + responseSequence.add(response); + } + public void setExpiresInSeconds(int expiresInSeconds) { this.expiresInSeconds = expiresInSeconds; } @@ -130,14 +136,19 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce return new MockLowLevelHttpRequest(url) { @Override public LowLevelHttpResponse execute() throws IOException { - IOException responseError = responseErrorSequence.poll(); - if (responseError != null) { - throw responseError; - } - LowLevelHttpResponse response = responseSequence.poll(); - if (response != null) { - return response; + + if (!responseSequence.isEmpty()) { + try { + return responseSequence.poll().get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + throw (IOException) cause; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Unexpectedly interrupted"); + } } + String content = this.getContentAsString(); Map query = TestUtils.parseQuery(content); String accessToken; diff --git a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java index 2f8c7d68a..f4eb16812 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java @@ -31,11 +31,13 @@ package com.google.auth.oauth2; +import static java.util.concurrent.TimeUnit.HOURS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -44,13 +46,31 @@ import com.google.auth.TestUtils; import com.google.auth.http.AuthHttpConstants; import com.google.auth.oauth2.GoogleCredentialsTest.MockTokenServerTransportFactory; +import com.google.auth.oauth2.OAuth2Credentials.OAuthValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFutureTask; +import com.google.common.util.concurrent.SettableFuture; import java.io.IOException; import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Date; import java.util.List; import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; import org.junit.Test; +import org.junit.function.ThrowingRunnable; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -64,6 +84,18 @@ public class OAuth2CredentialsTest extends BaseSerializationTest { private static final String ACCESS_TOKEN = "aashpFjkMkSJoj1xsli0H2eL5YsMgU_NKPY2TyGWY"; private static final URI CALL_URI = URI.create("http://googleapis.com/testapi/v1/foo"); + private ExecutorService realExecutor; + + @Before + public void setUp() { + realExecutor = Executors.newCachedThreadPool(); + } + + @After + public void tearDown() { + realExecutor.shutdown(); + } + @Test public void constructor_storesAccessToken() { OAuth2Credentials credentials = @@ -304,13 +336,14 @@ public void getRequestMetadata_async() throws IOException { } @Test - public void getRequestMetadata_async_refreshRace() throws IOException { + public void getRequestMetadata_async_refreshRace() + throws ExecutionException, InterruptedException { final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2"; MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET); transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1); TestClock clock = new TestClock(); - OAuth2Credentials credentials = + final OAuth2Credentials credentials = UserCredentials.newBuilder() .setClientId(CLIENT_ID) .setClientSecret(CLIENT_SECRET) @@ -326,14 +359,37 @@ public void getRequestMetadata_async_refreshRace() throws IOException { assertEquals(0, transportFactory.transport.buildRequestCount); assertNull(callback.metadata); - // Asynchronous task is scheduled, but beaten by another blocking get call. + // Asynchronous task is scheduled, and a blocking call follows it assertEquals(1, executor.numTasks()); - Map> metadata = credentials.getRequestMetadata(CALL_URI); - assertEquals(1, transportFactory.transport.buildRequestCount--); - TestUtils.assertContainsBearerToken(metadata, accessToken1); - // When the task is run, the cached data is used. + ExecutorService testExecutor = Executors.newFixedThreadPool(1); + + FutureTask>> blockingTask = + new FutureTask<>( + new Callable>>() { + @Override + public Map> call() throws Exception { + return credentials.getRequestMetadata(CALL_URI); + } + }); + + @SuppressWarnings("FutureReturnValueIgnored") + Future ignored = testExecutor.submit(blockingTask); + testExecutor.shutdown(); + + // give the blockingTask a chance to run + for (int i = 0; i < 10; i++) { + Thread.yield(); + } + + // blocking task is waiting on the async task to finish + assertFalse(blockingTask.isDone()); + assertEquals(0, transportFactory.transport.buildRequestCount); + + // When the task is run, the result is shared assertEquals(1, executor.runTasksExhaustively()); + assertEquals(1, transportFactory.transport.buildRequestCount--); + Map> metadata = blockingTask.get(); assertEquals(0, transportFactory.transport.buildRequestCount); assertEquals(metadata, callback.metadata); } @@ -348,6 +404,267 @@ public void getRequestMetadata_temporaryToken_hasToken() throws IOException { TestUtils.assertContainsBearerToken(metadata, ACCESS_TOKEN); } + @Test + public void getRequestMetadata_staleTemporaryToken() throws IOException, InterruptedException { + Calendar calendar = Calendar.getInstance(); + Date actualExpiration = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); + Date clientStale = calendar.getTime(); + + TestClock testClock = new TestClock(); + testClock.setCurrentTime(clientStale.getTime()); + + // Initialize credentials which are initially stale and set to refresh + final SettableFuture refreshedTokenFuture = SettableFuture.create(); + OAuth2Credentials creds = + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, actualExpiration)) { + @Override + public AccessToken refreshAccessToken() { + + try { + return refreshedTokenFuture.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + creds.clock = testClock; + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + + // Calls should return immediately with stale token + MockRequestMetadataCallback callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN); + TestUtils.assertContainsBearerToken(creds.getRequestMetadata(CALL_URI), ACCESS_TOKEN); + + // But a refresh task should be scheduled + synchronized (creds.lock) { + assertNotNull(creds.refreshTask); + } + + // Resolve the outstanding refresh + AccessToken refreshedToken = + new AccessToken( + "2/MkSJoj1xsli0AccessToken_NKPY2", + new Date(testClock.currentTimeMillis() + HOURS.toMillis(1))); + refreshedTokenFuture.set(refreshedToken); + + // The access token should available once the refresh thread completes + // However it will be populated asynchronously, so we need to wait until it propagates + // Wait at most 1 minute are 100ms intervals. It should never come close to this. + for (int i = 0; i < 600; i++) { + Map> requestMetadata = creds.getRequestMetadata(CALL_URI); + String s = requestMetadata.get(AuthHttpConstants.AUTHORIZATION).get(0); + if (s.contains(refreshedToken.getTokenValue())) { + break; + } + Thread.sleep(100); + } + + // Everything should return the new token + callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.metadata, refreshedToken.getTokenValue()); + TestUtils.assertContainsBearerToken( + creds.getRequestMetadata(CALL_URI), refreshedToken.getTokenValue()); + + // And the task slot is reset + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + } + + @Test + public void getRequestMetadata_staleTemporaryToken_expirationWaits() throws Throwable { + Calendar calendar = Calendar.getInstance(); + Date actualExpiration = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); + Date clientStale = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.MINIMUM_TOKEN_MILLISECONDS)); + Date clientExpired = calendar.getTime(); + + TestClock testClock = new TestClock(); + + // Initialize credentials which are initially stale and set to refresh + final SettableFuture refreshedTokenFuture = SettableFuture.create(); + OAuth2Credentials creds = + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, actualExpiration)) { + @Override + public AccessToken refreshAccessToken() { + + try { + return refreshedTokenFuture.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + creds.clock = testClock; + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + + // Calls should return immediately with stale token, but a refresh is scheduled + testClock.setCurrentTime(clientStale.getTime()); + MockRequestMetadataCallback callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN); + assertNotNull(creds.refreshTask); + ListenableFutureTask refreshTask = creds.refreshTask; + + // Fast forward to expiration, which will hang cause the callback to hang + testClock.setCurrentTime(clientExpired.getTime()); + // Make sure that the callback is hung (while giving it a chance to run) + for (int i = 0; i < 10; i++) { + Thread.sleep(10); + callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + assertNull(callback.metadata); + } + // The original refresh task should still be active + synchronized (creds.lock) { + assertSame(refreshTask, creds.refreshTask); + } + + // Resolve the outstanding refresh + AccessToken refreshedToken = + new AccessToken( + "2/MkSJoj1xsli0AccessToken_NKPY2", + new Date(testClock.currentTimeMillis() + HOURS.toMillis(1))); + refreshedTokenFuture.set(refreshedToken); + + // The access token should available once the refresh thread completes + TestUtils.assertContainsBearerToken( + creds.getRequestMetadata(CALL_URI), refreshedToken.getTokenValue()); + callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.awaitResult(), refreshedToken.getTokenValue()); + + // The refresh slot should be cleared + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + } + + @Test + public void getRequestMetadata_singleFlightErrorSharing() { + Calendar calendar = Calendar.getInstance(); + Date actualExpiration = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); + Date clientStale = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.MINIMUM_TOKEN_MILLISECONDS)); + Date clientExpired = calendar.getTime(); + + TestClock testClock = new TestClock(); + testClock.setCurrentTime(clientStale.getTime()); + + // Initialize credentials which are initially expired + final SettableFuture refreshErrorFuture = SettableFuture.create(); + final OAuth2Credentials creds = + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, clientExpired)) { + @Override + public AccessToken refreshAccessToken() { + RuntimeException injectedError; + + try { + injectedError = refreshErrorFuture.get(); + } catch (Exception e) { + throw new IllegalStateException("Unexpected error fetching injected error"); + } + throw injectedError; + } + }; + creds.clock = testClock; + + // Calls will hang waiting for the refresh + final MockRequestMetadataCallback callback1 = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback1); + + final Future>> blockingCall = + realExecutor.submit( + new Callable>>() { + @Override + public Map> call() throws Exception { + return creds.getRequestMetadata(CALL_URI); + } + }); + + RuntimeException error = new RuntimeException("fake error"); + refreshErrorFuture.set(error); + + // Get the error that getRequestMetadata(uri) created + Throwable actualBlockingError = + assertThrows( + ExecutionException.class, + new ThrowingRunnable() { + @Override + public void run() throws Throwable { + blockingCall.get(); + } + }) + .getCause(); + + assertEquals(error, actualBlockingError); + + RuntimeException actualAsyncError = + assertThrows( + RuntimeException.class, + new ThrowingRunnable() { + @Override + public void run() throws Throwable { + callback1.awaitResult(); + } + }); + assertEquals(error, actualAsyncError); + } + + @Test + public void getRequestMetadata_syncErrorsIncludeCallingStackframe() { + final OAuth2Credentials creds = + new OAuth2Credentials() { + @Override + public AccessToken refreshAccessToken() { + throw new RuntimeException("fake error"); + } + }; + + List expectedStacktrace = + new ArrayList<>(Arrays.asList(new Exception().getStackTrace())); + expectedStacktrace = expectedStacktrace.subList(1, expectedStacktrace.size()); + + AtomicReference actualError = new AtomicReference<>(); + try { + creds.getRequestMetadata(CALL_URI); + } catch (Exception refreshError) { + actualError.set(refreshError); + } + + List actualStacktrace = Arrays.asList(actualError.get().getStackTrace()); + actualStacktrace = + actualStacktrace.subList( + actualStacktrace.size() - expectedStacktrace.size(), actualStacktrace.size()); + + // ensure the remaining frames are identical + assertEquals(expectedStacktrace, actualStacktrace); + } + @Test public void refresh_refreshesToken() throws IOException { final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";