diff --git a/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java b/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java index 9293d5855..5ded1a140 100644 --- a/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/ExternalAccountCredentials.java @@ -35,6 +35,7 @@ import com.google.api.client.json.GenericJson; import com.google.api.client.json.JsonObjectParser; +import com.google.auth.RequestMetadataCallback; import com.google.auth.http.HttpTransportFactory; import com.google.auth.oauth2.AwsCredentials.AwsCredentialSource; import com.google.auth.oauth2.IdentityPoolCredentials.IdentityPoolCredentialSource; @@ -48,6 +49,7 @@ import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.Executor; import javax.annotation.Nullable; /** @@ -208,6 +210,26 @@ private ImpersonatedCredentials initializeImpersonatedCredentials() { .build(); } + @Override + public void getRequestMetadata( + URI uri, Executor executor, final RequestMetadataCallback callback) { + super.getRequestMetadata( + uri, + executor, + new RequestMetadataCallback() { + @Override + public void onSuccess(Map> metadata) { + metadata = addQuotaProjectIdToRequestMetadata(quotaProjectId, metadata); + callback.onSuccess(metadata); + } + + @Override + public void onFailure(Throwable exception) { + callback.onFailure(exception); + } + }); + } + @Override public Map> getRequestMetadata(URI uri) throws IOException { Map> requestMetadata = super.getRequestMetadata(uri); diff --git a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java index f22d3c74e..c1b11bdc2 100644 --- a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java @@ -31,37 +31,50 @@ package com.google.auth.oauth2; +import static java.util.concurrent.TimeUnit.MINUTES; + import com.google.api.client.util.Clock; import com.google.auth.Credentials; import com.google.auth.RequestMetadataCallback; import com.google.auth.http.AuthHttpConstants; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListenableFutureTask; +import com.google.common.util.concurrent.MoreExecutors; import java.io.IOException; import java.io.ObjectInputStream; +import java.io.Serializable; import java.net.URI; import java.util.ArrayList; -import java.util.Collections; import java.util.Date; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.ServiceLoader; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; +import java.util.concurrent.Future; +import javax.annotation.Nullable; /** Base type for Credentials using OAuth2. */ public class OAuth2Credentials extends Credentials { private static final long serialVersionUID = 4556936364828217687L; - private static final long MINIMUM_TOKEN_MILLISECONDS = 60000L * 5L; - private static final Map> EMPTY_EXTRA_HEADERS = Collections.emptyMap(); + static final long MINIMUM_TOKEN_MILLISECONDS = MINUTES.toMillis(5); + static final long REFRESH_MARGIN_MILLISECONDS = MINIMUM_TOKEN_MILLISECONDS + MINUTES.toMillis(1); + private static final ImmutableMap> EMPTY_EXTRA_HEADERS = ImmutableMap.of(); // byte[] is serializable, so the lock variable can be final - private final Object lock = new byte[0]; - private Map> requestMetadata; - private AccessToken temporaryAccess; + @VisibleForTesting final Object lock = new byte[0]; + private volatile OAuthValue value = null; + @VisibleForTesting transient ListenableFutureTask refreshTask; // Change listeners are not serialized private transient List changeListeners; @@ -90,7 +103,7 @@ protected OAuth2Credentials() { */ protected OAuth2Credentials(AccessToken accessToken) { if (accessToken != null) { - useAccessToken(accessToken, EMPTY_EXTRA_HEADERS); + this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS); } } @@ -117,25 +130,21 @@ public boolean hasRequestMetadataOnly() { * @return The cached access token. */ public final AccessToken getAccessToken() { - return temporaryAccess; + OAuthValue localState = value; + if (localState != null) { + return localState.temporaryAccess; + } + return null; } @Override public void getRequestMetadata( final URI uri, Executor executor, final RequestMetadataCallback callback) { - Map> metadata; - synchronized (lock) { - if (shouldRefresh()) { - // The base class implementation will do a blocking get in the executor. - super.getRequestMetadata(uri, executor, callback); - return; - } - if (requestMetadata == null) { - throw new NullPointerException("cached requestMetadata"); - } - metadata = requestMetadata; - } - callback.onSuccess(metadata); + + Futures.addCallback( + asyncFetch(executor), + new FutureCallbackToMetadataCallbackAdapter(callback), + MoreExecutors.directExecutor()); } /** @@ -144,75 +153,188 @@ public void getRequestMetadata( */ @Override public Map> getRequestMetadata(URI uri) throws IOException { - synchronized (lock) { - if (shouldRefresh()) { - refresh(); - } - if (requestMetadata == null) { - throw new NullPointerException("requestMetadata"); - } - return requestMetadata; - } + return unwrapDirectFuture(asyncFetch(MoreExecutors.directExecutor())).requestMetadata; } - /** Refresh the token by discarding the cached token and metadata and requesting the new ones. */ + /** + * Request a new token regardless of the current token state. If the current token is not expired, + * it will still be returned during the refresh. + */ @Override public void refresh() throws IOException { + AsyncRefreshResult refreshResult = getOrCreateRefreshTask(); + refreshResult.executeIfNew(MoreExecutors.directExecutor()); + unwrapDirectFuture(refreshResult.task); + } + + /** + * Refresh these credentials only if they have expired or are expiring imminently. + * + * @throws IOException during token refresh. + */ + public void refreshIfExpired() throws IOException { + // asyncFetch will ensure that the token is refreshed + unwrapDirectFuture(asyncFetch(MoreExecutors.directExecutor())); + } + + /** + * Attempts to get a fresh token. + * + *

If a fresh token is already available, it will be immediately returned. Otherwise a refresh + * will be scheduled using the passed in executor. While a token is being freshed, a stale value + * will be returned. + */ + private ListenableFuture asyncFetch(Executor executor) { + AsyncRefreshResult refreshResult = null; + + // fast and common path: skip the lock if the token is fresh + // The inherent race condition here is a non-issue: even if the value gets replaced after the + // state check, the new token will still be fresh. + if (getState() == CacheState.FRESH) { + return Futures.immediateFuture(value); + } + + // Schedule a refresh as necessary synchronized (lock) { - requestMetadata = null; - temporaryAccess = null; - AccessToken accessToken = refreshAccessToken(); - if (accessToken == null) { - throw new NullPointerException("new access token"); + if (getState() != CacheState.FRESH) { + refreshResult = getOrCreateRefreshTask(); } - useAccessToken(accessToken, getAdditionalHeaders()); - if (changeListeners != null) { - for (CredentialsChangedListener listener : changeListeners) { - listener.onChanged(this); - } + } + // Execute the refresh if necessary. This should be done outside of the lock to avoid blocking + // metadata requests during a stale refresh. + if (refreshResult != null) { + refreshResult.executeIfNew(executor); + } + + synchronized (lock) { + // Immediately resolve the token token if its not expired, or wait for the refresh task to + // complete + if (getState() != CacheState.EXPIRED) { + return Futures.immediateFuture(value); + } else if (refreshResult != null) { + return refreshResult.task; + } else { + // Should never happen + return Futures.immediateFailedFuture( + new IllegalStateException("Credentials expired, but there is no task to refresh")); } } } /** - * Provide additional headers to return as request metadata. + * Atomically creates a single flight refresh token task. * - * @return additional headers + *

Only a single refresh task can be scheduled at a time. If there is an existing task, it will + * be returned for subsequent invocations. However if a new task is created, it is the + * responsibility of the caller to execute it. The task will clear the single flight slow upon + * completion. */ - protected Map> getAdditionalHeaders() { - return EMPTY_EXTRA_HEADERS; + private AsyncRefreshResult getOrCreateRefreshTask() { + synchronized (lock) { + if (refreshTask != null) { + return new AsyncRefreshResult(refreshTask, false); + } + + final ListenableFutureTask task = + ListenableFutureTask.create( + new Callable() { + @Override + public OAuthValue call() throws Exception { + return OAuthValue.create(refreshAccessToken(), getAdditionalHeaders()); + } + }); + + task.addListener( + new Runnable() { + @Override + public void run() { + finishRefreshAsync(task); + } + }, + MoreExecutors.directExecutor()); + + refreshTask = task; + + return new AsyncRefreshResult(refreshTask, true); + } } /** - * Refresh these credentials only if they have expired or are expiring imminently. + * Async callback for committing the result from a token refresh. * - * @throws IOException during token refresh. + *

The result will be stored, listeners are invoked and the single flight slot is cleared. */ - public void refreshIfExpired() throws IOException { + private void finishRefreshAsync(ListenableFuture finishedTask) { synchronized (lock) { - if (shouldRefresh()) { - refresh(); + try { + this.value = finishedTask.get(); + for (CredentialsChangedListener listener : changeListeners) { + listener.onChanged(this); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (Exception e) { + // noop + } finally { + if (this.refreshTask == finishedTask) { + this.refreshTask = null; + } } } } - // Must be called under lock - private void useAccessToken(AccessToken token, Map> additionalHeaders) { - this.temporaryAccess = token; - this.requestMetadata = - ImmutableMap.>builder() - .put( - AuthHttpConstants.AUTHORIZATION, - Collections.singletonList(OAuth2Utils.BEARER_PREFIX + token.getTokenValue())) - .putAll(additionalHeaders) - .build(); + /** + * Unwraps the value from the future. + * + *

Under most circumstances, the underlying future will already be resolved by the + * DirectExecutor. In those cases, the error stacktraces will be rooted in the caller's call tree. + * However, in some cases when async and sync usage is mixed, it's possible that a blocking call + * will await an async future. In those cases, the stacktrace will be orphaned and be rooted in a + * thread of whatever executor the async call used. This doesn't affect correctness and is + * extremely unlikely. + */ + private static T unwrapDirectFuture(Future future) throws IOException { + try { + return future.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Interrupted while asynchronously refreshing the access token", e); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + throw (IOException) cause; + } else if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } else { + throw new IOException("Unexpected error refreshing access token", cause); + } + } } - // Must be called under lock - // requestMetadata will never be null if false is returned. - private boolean shouldRefresh() { - Long expiresIn = getExpiresInMilliseconds(); - return requestMetadata == null || expiresIn != null && expiresIn <= MINIMUM_TOKEN_MILLISECONDS; + /** Computes the effective credential state in relation to the current time. */ + private CacheState getState() { + OAuthValue localValue = value; + + if (localValue == null) { + return CacheState.EXPIRED; + } + Date expirationTime = localValue.temporaryAccess.getExpirationTime(); + + if (expirationTime == null) { + return CacheState.FRESH; + } + + long remainingMillis = expirationTime.getTime() - clock.currentTimeMillis(); + + if (remainingMillis <= MINIMUM_TOKEN_MILLISECONDS) { + return CacheState.EXPIRED; + } + + if (remainingMillis <= REFRESH_MARGIN_MILLISECONDS) { + return CacheState.STALE; + } + + return CacheState.FRESH; } /** @@ -233,6 +355,15 @@ public AccessToken refreshAccessToken() throws IOException { + " that supports refreshing."); } + /** + * Provide additional headers to return as request metadata. + * + * @return additional headers + */ + protected Map> getAdditionalHeaders() { + return EMPTY_EXTRA_HEADERS; + } + /** * Adds a listener that is notified when the Credentials data changes. * @@ -263,21 +394,6 @@ public final void removeChangeListener(CredentialsChangedListener listener) { } } - /** - * Return the remaining time the current access token will be valid, or null if there is no token - * or expiry information. Must be called under lock. - */ - private Long getExpiresInMilliseconds() { - if (temporaryAccess == null) { - return null; - } - Date expirationTime = temporaryAccess.getExpirationTime(); - if (expirationTime == null) { - return null; - } - return (expirationTime.getTime() - clock.currentTimeMillis()); - } - /** * Listener for changes to credentials. * @@ -300,15 +416,29 @@ public interface CredentialsChangedListener { @Override public int hashCode() { - return Objects.hash(requestMetadata, temporaryAccess); + return Objects.hashCode(value); } + @Nullable protected Map> getRequestMetadataInternal() { - return requestMetadata; + OAuthValue localValue = value; + if (localValue != null) { + return localValue.requestMetadata; + } + return null; } @Override public String toString() { + OAuthValue localValue = value; + + Map> requestMetadata = null; + AccessToken temporaryAccess = null; + + if (localValue != null) { + requestMetadata = localValue.requestMetadata; + temporaryAccess = localValue.temporaryAccess; + } return MoreObjects.toStringHelper(this) .add("requestMetadata", requestMetadata) .add("temporaryAccess", temporaryAccess) @@ -321,13 +451,13 @@ public boolean equals(Object obj) { return false; } OAuth2Credentials other = (OAuth2Credentials) obj; - return Objects.equals(this.requestMetadata, other.requestMetadata) - && Objects.equals(this.temporaryAccess, other.temporaryAccess); + return Objects.equals(this.value, other.value); } private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException { input.defaultReadObject(); clock = Clock.SYSTEM; + refreshTask = null; } @SuppressWarnings("unchecked") @@ -351,6 +481,94 @@ public Builder toBuilder() { return new Builder(this); } + /** Stores an immutable snapshot of the accesstoken owned by {@link OAuth2Credentials} */ + static class OAuthValue implements Serializable { + private final AccessToken temporaryAccess; + private final Map> requestMetadata; + + static OAuthValue create(AccessToken token, Map> additionalHeaders) { + return new OAuthValue( + token, + ImmutableMap.>builder() + .put( + AuthHttpConstants.AUTHORIZATION, + ImmutableList.of(OAuth2Utils.BEARER_PREFIX + token.getTokenValue())) + .putAll(additionalHeaders) + .build()); + } + + private OAuthValue(AccessToken temporaryAccess, Map> requestMetadata) { + this.temporaryAccess = temporaryAccess; + this.requestMetadata = requestMetadata; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof OAuthValue)) { + return false; + } + OAuthValue other = (OAuthValue) obj; + return Objects.equals(this.requestMetadata, other.requestMetadata) + && Objects.equals(this.temporaryAccess, other.temporaryAccess); + } + + @Override + public int hashCode() { + return Objects.hash(temporaryAccess, requestMetadata); + } + } + + enum CacheState { + FRESH, + STALE, + EXPIRED; + } + + static class FutureCallbackToMetadataCallbackAdapter implements FutureCallback { + private final RequestMetadataCallback callback; + + public FutureCallbackToMetadataCallbackAdapter(RequestMetadataCallback callback) { + this.callback = callback; + } + + @Override + public void onSuccess(@Nullable OAuthValue value) { + callback.onSuccess(value.requestMetadata); + } + + @Override + public void onFailure(Throwable throwable) { + // refreshAccessToken will be invoked in an executor, so if it fails unwrap the underlying + // error + if (throwable instanceof ExecutionException) { + throwable = throwable.getCause(); + } + callback.onFailure(throwable); + } + } + + /** + * Result from {@link com.google.auth.oauth2.OAuth2Credentials#getOrCreateRefreshTask()}. + * + *

Contains the the refresh task and a flag indicating if the task is newly created. If the + * task is newly created, it is the caller's responsibility to execute it. + */ + static class AsyncRefreshResult { + private final ListenableFutureTask task; + private final boolean isNew; + + AsyncRefreshResult(ListenableFutureTask task, boolean isNew) { + this.task = task; + this.isNew = isNew; + } + + void executeIfNew(Executor executor) { + if (isNew) { + executor.execute(task); + } + } + } + public static class Builder { private AccessToken accessToken; diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java b/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java index db50328c3..691879eae 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java @@ -35,17 +35,20 @@ import com.google.common.base.Preconditions; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; /** Mock RequestMetadataCallback */ public final class MockRequestMetadataCallback implements RequestMetadataCallback { Map> metadata; Throwable exception; + CountDownLatch latch = new CountDownLatch(1); /** Called when metadata is successfully produced. */ @Override public void onSuccess(Map> metadata) { checkNotSet(); this.metadata = metadata; + latch.countDown(); } /** Called when metadata generation failed. */ @@ -53,11 +56,22 @@ public void onSuccess(Map> metadata) { public void onFailure(Throwable exception) { checkNotSet(); this.exception = exception; + latch.countDown(); } public void reset() { this.metadata = null; this.exception = null; + latch = new CountDownLatch(1); + } + + public Map> awaitResult() throws Throwable { + latch.await(); + if (exception != null) { + throw exception; + } else { + return metadata; + } } private void checkNotSet() { diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java b/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java index efa09d7c7..4b431d28a 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockTokenServerTransport.java @@ -42,6 +42,7 @@ import com.google.api.client.testing.http.MockLowLevelHttpRequest; import com.google.api.client.testing.http.MockLowLevelHttpResponse; import com.google.auth.TestUtils; +import com.google.common.util.concurrent.Futures; import java.io.IOException; import java.net.URI; import java.util.ArrayDeque; @@ -49,6 +50,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Queue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; /** Mock transport to simulate providing Google OAuth2 access tokens */ public class MockTokenServerTransport extends MockHttpTransport { @@ -62,8 +65,7 @@ public class MockTokenServerTransport extends MockHttpTransport { final Map codes = new HashMap(); URI tokenServerUri = OAuth2Utils.TOKEN_SERVER_URI; private IOException error; - private Queue responseErrorSequence = new ArrayDeque(); - private Queue responseSequence = new ArrayDeque(); + private final Queue> responseSequence = new ArrayDeque<>(); private int expiresInSeconds = 3600; public MockTokenServerTransport() {} @@ -103,16 +105,20 @@ public void setError(IOException error) { public void addResponseErrorSequence(IOException... errors) { for (IOException error : errors) { - responseErrorSequence.add(error); + responseSequence.add(Futures.immediateFailedFuture(error)); } } public void addResponseSequence(LowLevelHttpResponse... responses) { for (LowLevelHttpResponse response : responses) { - responseSequence.add(response); + responseSequence.add(Futures.immediateFuture(response)); } } + public void addResponseSequence(Future response) { + responseSequence.add(response); + } + public void setExpiresInSeconds(int expiresInSeconds) { this.expiresInSeconds = expiresInSeconds; } @@ -130,14 +136,19 @@ public LowLevelHttpRequest buildRequest(String method, String url) throws IOExce return new MockLowLevelHttpRequest(url) { @Override public LowLevelHttpResponse execute() throws IOException { - IOException responseError = responseErrorSequence.poll(); - if (responseError != null) { - throw responseError; - } - LowLevelHttpResponse response = responseSequence.poll(); - if (response != null) { - return response; + + if (!responseSequence.isEmpty()) { + try { + return responseSequence.poll().get(); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + throw (IOException) cause; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Unexpectedly interrupted"); + } } + String content = this.getContentAsString(); Map query = TestUtils.parseQuery(content); String accessToken; diff --git a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java index 2f8c7d68a..f4eb16812 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/OAuth2CredentialsTest.java @@ -31,11 +31,13 @@ package com.google.auth.oauth2; +import static java.util.concurrent.TimeUnit.HOURS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -44,13 +46,31 @@ import com.google.auth.TestUtils; import com.google.auth.http.AuthHttpConstants; import com.google.auth.oauth2.GoogleCredentialsTest.MockTokenServerTransportFactory; +import com.google.auth.oauth2.OAuth2Credentials.OAuthValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.primitives.Ints; +import com.google.common.util.concurrent.ListenableFutureTask; +import com.google.common.util.concurrent.SettableFuture; import java.io.IOException; import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Date; import java.util.List; import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; import org.junit.Test; +import org.junit.function.ThrowingRunnable; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -64,6 +84,18 @@ public class OAuth2CredentialsTest extends BaseSerializationTest { private static final String ACCESS_TOKEN = "aashpFjkMkSJoj1xsli0H2eL5YsMgU_NKPY2TyGWY"; private static final URI CALL_URI = URI.create("http://googleapis.com/testapi/v1/foo"); + private ExecutorService realExecutor; + + @Before + public void setUp() { + realExecutor = Executors.newCachedThreadPool(); + } + + @After + public void tearDown() { + realExecutor.shutdown(); + } + @Test public void constructor_storesAccessToken() { OAuth2Credentials credentials = @@ -304,13 +336,14 @@ public void getRequestMetadata_async() throws IOException { } @Test - public void getRequestMetadata_async_refreshRace() throws IOException { + public void getRequestMetadata_async_refreshRace() + throws ExecutionException, InterruptedException { final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2"; MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET); transportFactory.transport.addRefreshToken(REFRESH_TOKEN, accessToken1); TestClock clock = new TestClock(); - OAuth2Credentials credentials = + final OAuth2Credentials credentials = UserCredentials.newBuilder() .setClientId(CLIENT_ID) .setClientSecret(CLIENT_SECRET) @@ -326,14 +359,37 @@ public void getRequestMetadata_async_refreshRace() throws IOException { assertEquals(0, transportFactory.transport.buildRequestCount); assertNull(callback.metadata); - // Asynchronous task is scheduled, but beaten by another blocking get call. + // Asynchronous task is scheduled, and a blocking call follows it assertEquals(1, executor.numTasks()); - Map> metadata = credentials.getRequestMetadata(CALL_URI); - assertEquals(1, transportFactory.transport.buildRequestCount--); - TestUtils.assertContainsBearerToken(metadata, accessToken1); - // When the task is run, the cached data is used. + ExecutorService testExecutor = Executors.newFixedThreadPool(1); + + FutureTask>> blockingTask = + new FutureTask<>( + new Callable>>() { + @Override + public Map> call() throws Exception { + return credentials.getRequestMetadata(CALL_URI); + } + }); + + @SuppressWarnings("FutureReturnValueIgnored") + Future ignored = testExecutor.submit(blockingTask); + testExecutor.shutdown(); + + // give the blockingTask a chance to run + for (int i = 0; i < 10; i++) { + Thread.yield(); + } + + // blocking task is waiting on the async task to finish + assertFalse(blockingTask.isDone()); + assertEquals(0, transportFactory.transport.buildRequestCount); + + // When the task is run, the result is shared assertEquals(1, executor.runTasksExhaustively()); + assertEquals(1, transportFactory.transport.buildRequestCount--); + Map> metadata = blockingTask.get(); assertEquals(0, transportFactory.transport.buildRequestCount); assertEquals(metadata, callback.metadata); } @@ -348,6 +404,267 @@ public void getRequestMetadata_temporaryToken_hasToken() throws IOException { TestUtils.assertContainsBearerToken(metadata, ACCESS_TOKEN); } + @Test + public void getRequestMetadata_staleTemporaryToken() throws IOException, InterruptedException { + Calendar calendar = Calendar.getInstance(); + Date actualExpiration = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); + Date clientStale = calendar.getTime(); + + TestClock testClock = new TestClock(); + testClock.setCurrentTime(clientStale.getTime()); + + // Initialize credentials which are initially stale and set to refresh + final SettableFuture refreshedTokenFuture = SettableFuture.create(); + OAuth2Credentials creds = + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, actualExpiration)) { + @Override + public AccessToken refreshAccessToken() { + + try { + return refreshedTokenFuture.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + creds.clock = testClock; + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + + // Calls should return immediately with stale token + MockRequestMetadataCallback callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN); + TestUtils.assertContainsBearerToken(creds.getRequestMetadata(CALL_URI), ACCESS_TOKEN); + + // But a refresh task should be scheduled + synchronized (creds.lock) { + assertNotNull(creds.refreshTask); + } + + // Resolve the outstanding refresh + AccessToken refreshedToken = + new AccessToken( + "2/MkSJoj1xsli0AccessToken_NKPY2", + new Date(testClock.currentTimeMillis() + HOURS.toMillis(1))); + refreshedTokenFuture.set(refreshedToken); + + // The access token should available once the refresh thread completes + // However it will be populated asynchronously, so we need to wait until it propagates + // Wait at most 1 minute are 100ms intervals. It should never come close to this. + for (int i = 0; i < 600; i++) { + Map> requestMetadata = creds.getRequestMetadata(CALL_URI); + String s = requestMetadata.get(AuthHttpConstants.AUTHORIZATION).get(0); + if (s.contains(refreshedToken.getTokenValue())) { + break; + } + Thread.sleep(100); + } + + // Everything should return the new token + callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.metadata, refreshedToken.getTokenValue()); + TestUtils.assertContainsBearerToken( + creds.getRequestMetadata(CALL_URI), refreshedToken.getTokenValue()); + + // And the task slot is reset + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + } + + @Test + public void getRequestMetadata_staleTemporaryToken_expirationWaits() throws Throwable { + Calendar calendar = Calendar.getInstance(); + Date actualExpiration = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); + Date clientStale = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.MINIMUM_TOKEN_MILLISECONDS)); + Date clientExpired = calendar.getTime(); + + TestClock testClock = new TestClock(); + + // Initialize credentials which are initially stale and set to refresh + final SettableFuture refreshedTokenFuture = SettableFuture.create(); + OAuth2Credentials creds = + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, actualExpiration)) { + @Override + public AccessToken refreshAccessToken() { + + try { + return refreshedTokenFuture.get(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + creds.clock = testClock; + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + + // Calls should return immediately with stale token, but a refresh is scheduled + testClock.setCurrentTime(clientStale.getTime()); + MockRequestMetadataCallback callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.metadata, ACCESS_TOKEN); + assertNotNull(creds.refreshTask); + ListenableFutureTask refreshTask = creds.refreshTask; + + // Fast forward to expiration, which will hang cause the callback to hang + testClock.setCurrentTime(clientExpired.getTime()); + // Make sure that the callback is hung (while giving it a chance to run) + for (int i = 0; i < 10; i++) { + Thread.sleep(10); + callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + assertNull(callback.metadata); + } + // The original refresh task should still be active + synchronized (creds.lock) { + assertSame(refreshTask, creds.refreshTask); + } + + // Resolve the outstanding refresh + AccessToken refreshedToken = + new AccessToken( + "2/MkSJoj1xsli0AccessToken_NKPY2", + new Date(testClock.currentTimeMillis() + HOURS.toMillis(1))); + refreshedTokenFuture.set(refreshedToken); + + // The access token should available once the refresh thread completes + TestUtils.assertContainsBearerToken( + creds.getRequestMetadata(CALL_URI), refreshedToken.getTokenValue()); + callback = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback); + TestUtils.assertContainsBearerToken(callback.awaitResult(), refreshedToken.getTokenValue()); + + // The refresh slot should be cleared + synchronized (creds.lock) { + assertNull(creds.refreshTask); + } + } + + @Test + public void getRequestMetadata_singleFlightErrorSharing() { + Calendar calendar = Calendar.getInstance(); + Date actualExpiration = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.REFRESH_MARGIN_MILLISECONDS)); + Date clientStale = calendar.getTime(); + + calendar.setTime(actualExpiration); + calendar.add( + Calendar.MILLISECOND, -1 * Ints.checkedCast(OAuth2Credentials.MINIMUM_TOKEN_MILLISECONDS)); + Date clientExpired = calendar.getTime(); + + TestClock testClock = new TestClock(); + testClock.setCurrentTime(clientStale.getTime()); + + // Initialize credentials which are initially expired + final SettableFuture refreshErrorFuture = SettableFuture.create(); + final OAuth2Credentials creds = + new OAuth2Credentials(new AccessToken(ACCESS_TOKEN, clientExpired)) { + @Override + public AccessToken refreshAccessToken() { + RuntimeException injectedError; + + try { + injectedError = refreshErrorFuture.get(); + } catch (Exception e) { + throw new IllegalStateException("Unexpected error fetching injected error"); + } + throw injectedError; + } + }; + creds.clock = testClock; + + // Calls will hang waiting for the refresh + final MockRequestMetadataCallback callback1 = new MockRequestMetadataCallback(); + creds.getRequestMetadata(CALL_URI, realExecutor, callback1); + + final Future>> blockingCall = + realExecutor.submit( + new Callable>>() { + @Override + public Map> call() throws Exception { + return creds.getRequestMetadata(CALL_URI); + } + }); + + RuntimeException error = new RuntimeException("fake error"); + refreshErrorFuture.set(error); + + // Get the error that getRequestMetadata(uri) created + Throwable actualBlockingError = + assertThrows( + ExecutionException.class, + new ThrowingRunnable() { + @Override + public void run() throws Throwable { + blockingCall.get(); + } + }) + .getCause(); + + assertEquals(error, actualBlockingError); + + RuntimeException actualAsyncError = + assertThrows( + RuntimeException.class, + new ThrowingRunnable() { + @Override + public void run() throws Throwable { + callback1.awaitResult(); + } + }); + assertEquals(error, actualAsyncError); + } + + @Test + public void getRequestMetadata_syncErrorsIncludeCallingStackframe() { + final OAuth2Credentials creds = + new OAuth2Credentials() { + @Override + public AccessToken refreshAccessToken() { + throw new RuntimeException("fake error"); + } + }; + + List expectedStacktrace = + new ArrayList<>(Arrays.asList(new Exception().getStackTrace())); + expectedStacktrace = expectedStacktrace.subList(1, expectedStacktrace.size()); + + AtomicReference actualError = new AtomicReference<>(); + try { + creds.getRequestMetadata(CALL_URI); + } catch (Exception refreshError) { + actualError.set(refreshError); + } + + List actualStacktrace = Arrays.asList(actualError.get().getStackTrace()); + actualStacktrace = + actualStacktrace.subList( + actualStacktrace.size() - expectedStacktrace.size(), actualStacktrace.size()); + + // ensure the remaining frames are identical + assertEquals(expectedStacktrace, actualStacktrace); + } + @Test public void refresh_refreshesToken() throws IOException { final String accessToken1 = "1/MkSJoj1xsli0AccessToken_NKPY2";