From 0a8412fcf9de4ac568b9f88618e44087dd31b144 Mon Sep 17 00:00:00 2001 From: Jeff Ching Date: Thu, 3 Dec 2020 13:06:36 -0800 Subject: [PATCH] fix: quotaProjectId should be applied for cached `getRequestMetadata(URI, Executor, RequestMetadataCallback)` (#509) * test: add failing test for request metadata with callback * refactor: use protected getAdditionalHeaders to provide quotaProjectId and other headers * docs: add javadoc for new protected getAdditionalHeaders() * test: add more tests for quotaProjectId * docs: fix nit javadoc return format --- .../google/auth/oauth2/OAuth2Credentials.java | 28 ++++-- .../oauth2/ServiceAccountCredentials.java | 9 +- .../google/auth/oauth2/UserCredentials.java | 9 +- .../oauth2/ServiceAccountCredentialsTest.java | 93 +++++++++++++++++++ ...erviceAccountJwtAccessCredentialsTest.java | 72 ++++++++++++++ .../auth/oauth2/UserCredentialsTest.java | 77 +++++++++++++++ 6 files changed, 276 insertions(+), 12 deletions(-) diff --git a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java index aca1e3490..ab6c042da 100644 --- a/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java @@ -38,6 +38,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import java.io.IOException; import java.io.ObjectInputStream; @@ -56,6 +57,7 @@ public class OAuth2Credentials extends Credentials { private static final long serialVersionUID = 4556936364828217687L; private static final long MINIMUM_TOKEN_MILLISECONDS = 60000L * 5L; + private static final Map> EMPTY_EXTRA_HEADERS = Collections.emptyMap(); // byte[] is serializable, so the lock variable can be final private final Object lock = new byte[0]; @@ -89,7 +91,7 @@ protected OAuth2Credentials() { */ protected OAuth2Credentials(AccessToken accessToken) { if (accessToken != null) { - useAccessToken(accessToken); + useAccessToken(accessToken, EMPTY_EXTRA_HEADERS); } } @@ -154,7 +156,9 @@ public void refresh() throws IOException { synchronized (lock) { requestMetadata = null; temporaryAccess = null; - useAccessToken(Preconditions.checkNotNull(refreshAccessToken(), "new access token")); + useAccessToken( + Preconditions.checkNotNull(refreshAccessToken(), "new access token"), + getAdditionalHeaders()); if (changeListeners != null) { for (CredentialsChangedListener listener : changeListeners) { listener.onChanged(this); @@ -163,6 +167,15 @@ public void refresh() throws IOException { } } + /** + * Provide additional headers to return as request metadata. + * + * @return additional headers + */ + protected Map> getAdditionalHeaders() { + return EMPTY_EXTRA_HEADERS; + } + /** * Refresh these credentials only if they have expired or are expiring imminently. * @@ -177,12 +190,15 @@ public void refreshIfExpired() throws IOException { } // Must be called under lock - private void useAccessToken(AccessToken token) { + private void useAccessToken(AccessToken token, Map> additionalHeaders) { this.temporaryAccess = token; this.requestMetadata = - Collections.singletonMap( - AuthHttpConstants.AUTHORIZATION, - Collections.singletonList(OAuth2Utils.BEARER_PREFIX + token.getTokenValue())); + ImmutableMap.>builder() + .put( + AuthHttpConstants.AUTHORIZATION, + Collections.singletonList(OAuth2Utils.BEARER_PREFIX + token.getTokenValue())) + .putAll(additionalHeaders) + .build(); } // Must be called under lock diff --git a/oauth2_http/java/com/google/auth/oauth2/ServiceAccountCredentials.java b/oauth2_http/java/com/google/auth/oauth2/ServiceAccountCredentials.java index f26448545..974959129 100644 --- a/oauth2_http/java/com/google/auth/oauth2/ServiceAccountCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/ServiceAccountCredentials.java @@ -599,9 +599,12 @@ public JwtCredentials jwtWithClaims(JwtClaims newClaims) { } @Override - public Map> getRequestMetadata(URI uri) throws IOException { - Map> requestMetadata = super.getRequestMetadata(uri); - return addQuotaProjectIdToRequestMetadata(quotaProjectId, requestMetadata); + protected Map> getAdditionalHeaders() { + Map> headers = super.getAdditionalHeaders(); + if (quotaProjectId != null) { + return addQuotaProjectIdToRequestMetadata(quotaProjectId, headers); + } + return headers; } @Override diff --git a/oauth2_http/java/com/google/auth/oauth2/UserCredentials.java b/oauth2_http/java/com/google/auth/oauth2/UserCredentials.java index 5010a9ae6..bd8dcfd61 100644 --- a/oauth2_http/java/com/google/auth/oauth2/UserCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/UserCredentials.java @@ -277,9 +277,12 @@ public void save(String filePath) throws IOException { } @Override - public Map> getRequestMetadata(URI uri) throws IOException { - Map> requestMetadata = super.getRequestMetadata(uri); - return addQuotaProjectIdToRequestMetadata(quotaProjectId, requestMetadata); + protected Map> getAdditionalHeaders() { + Map> headers = super.getAdditionalHeaders(); + if (quotaProjectId != null) { + return addQuotaProjectIdToRequestMetadata(quotaProjectId, headers); + } + return headers; } @Override diff --git a/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountCredentialsTest.java index 779c7f006..7d89ded8f 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountCredentialsTest.java @@ -49,6 +49,7 @@ import com.google.api.client.testing.http.MockLowLevelHttpResponse; import com.google.api.client.util.Clock; import com.google.api.client.util.Joiner; +import com.google.auth.RequestMetadataCallback; import com.google.auth.TestUtils; import com.google.auth.http.HttpTransportFactory; import com.google.auth.oauth2.GoogleCredentialsTest.MockHttpTransportFactory; @@ -68,6 +69,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -1032,6 +1034,97 @@ public void fromStream_noPrivateKeyId_throws() throws IOException { testFromStreamException(serviceAccountStream, "private_key_id"); } + @Test + public void getRequestMetadataSetsQuotaProjectId() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID, "unused-client-secret"); + transportFactory.transport.addServiceAccount(CLIENT_EMAIL, ACCESS_TOKEN); + + PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(PRIVATE_KEY_PKCS8); + GoogleCredentials credentials = + ServiceAccountCredentials.newBuilder() + .setClientId(CLIENT_ID) + .setClientEmail(CLIENT_EMAIL) + .setPrivateKey(privateKey) + .setPrivateKeyId(PRIVATE_KEY_ID) + .setScopes(SCOPES) + .setServiceAccountUser(USER) + .setProjectId(PROJECT_ID) + .setQuotaProjectId("my-quota-project-id") + .setHttpTransportFactory(transportFactory) + .build(); + + Map> metadata = credentials.getRequestMetadata(); + assertTrue(metadata.containsKey("x-goog-user-project")); + List headerValues = metadata.get("x-goog-user-project"); + assertEquals(1, headerValues.size()); + assertEquals("my-quota-project-id", headerValues.get(0)); + } + + @Test + public void getRequestMetadataNoQuotaProjectId() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID, "unused-client-secret"); + transportFactory.transport.addServiceAccount(CLIENT_EMAIL, ACCESS_TOKEN); + + PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(PRIVATE_KEY_PKCS8); + GoogleCredentials credentials = + ServiceAccountCredentials.newBuilder() + .setClientId(CLIENT_ID) + .setClientEmail(CLIENT_EMAIL) + .setPrivateKey(privateKey) + .setPrivateKeyId(PRIVATE_KEY_ID) + .setScopes(SCOPES) + .setServiceAccountUser(USER) + .setProjectId(PROJECT_ID) + .setHttpTransportFactory(transportFactory) + .build(); + + Map> metadata = credentials.getRequestMetadata(); + assertFalse(metadata.containsKey("x-goog-user-project")); + } + + @Test + public void getRequestMetadataWithCallback() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID, "unused-client-secret"); + transportFactory.transport.addServiceAccount(CLIENT_EMAIL, ACCESS_TOKEN); + + PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(PRIVATE_KEY_PKCS8); + GoogleCredentials credentials = + ServiceAccountCredentials.newBuilder() + .setClientId(CLIENT_ID) + .setClientEmail(CLIENT_EMAIL) + .setPrivateKey(privateKey) + .setPrivateKeyId(PRIVATE_KEY_ID) + .setScopes(SCOPES) + .setServiceAccountUser(USER) + .setProjectId(PROJECT_ID) + .setQuotaProjectId("my-quota-project-id") + .setHttpTransportFactory(transportFactory) + .build(); + + final Map> plainMetadata = credentials.getRequestMetadata(); + final AtomicBoolean success = new AtomicBoolean(false); + credentials.getRequestMetadata( + null, + null, + new RequestMetadataCallback() { + @Override + public void onSuccess(Map> metadata) { + assertEquals(plainMetadata, metadata); + success.set(true); + } + + @Override + public void onFailure(Throwable exception) { + fail("Should not throw a failure."); + } + }); + + assertTrue("Should have run onSuccess() callback", success.get()); + } + static GenericJson writeServiceAccountJson( String clientId, String clientEmail, diff --git a/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountJwtAccessCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountJwtAccessCredentialsTest.java index 64ca7d408..c3d73072f 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountJwtAccessCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/ServiceAccountJwtAccessCredentialsTest.java @@ -46,6 +46,7 @@ import com.google.api.client.json.webtoken.JsonWebSignature; import com.google.api.client.util.Clock; import com.google.auth.Credentials; +import com.google.auth.RequestMetadataCallback; import com.google.auth.TestClock; import com.google.auth.http.AuthHttpConstants; import com.google.auth.oauth2.GoogleCredentialsTest.MockHttpTransportFactory; @@ -61,6 +62,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -758,6 +760,76 @@ public void jwtWithClaims_defaultAudience() throws IOException { verifyJwtAccess(metadata, SA_CLIENT_EMAIL, URI.create("default-audience"), SA_PRIVATE_KEY_ID); } + @Test + public void getRequestMetadataSetsQuotaProjectId() throws IOException { + PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8); + ServiceAccountJwtAccessCredentials credentials = + ServiceAccountJwtAccessCredentials.newBuilder() + .setClientId(SA_CLIENT_ID) + .setClientEmail(SA_CLIENT_EMAIL) + .setPrivateKey(privateKey) + .setPrivateKeyId(SA_PRIVATE_KEY_ID) + .setQuotaProjectId("my-quota-project-id") + .setDefaultAudience(URI.create("default-audience")) + .build(); + + Map> metadata = credentials.getRequestMetadata(); + assertTrue(metadata.containsKey("x-goog-user-project")); + List headerValues = metadata.get("x-goog-user-project"); + assertEquals(1, headerValues.size()); + assertEquals("my-quota-project-id", headerValues.get(0)); + } + + @Test + public void getRequestMetadataNoQuotaProjectId() throws IOException { + PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8); + ServiceAccountJwtAccessCredentials credentials = + ServiceAccountJwtAccessCredentials.newBuilder() + .setClientId(SA_CLIENT_ID) + .setClientEmail(SA_CLIENT_EMAIL) + .setPrivateKey(privateKey) + .setPrivateKeyId(SA_PRIVATE_KEY_ID) + .setDefaultAudience(URI.create("default-audience")) + .build(); + + Map> metadata = credentials.getRequestMetadata(); + assertFalse(metadata.containsKey("x-goog-user-project")); + } + + @Test + public void getRequestMetadataWithCallback() throws IOException { + PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8); + ServiceAccountJwtAccessCredentials credentials = + ServiceAccountJwtAccessCredentials.newBuilder() + .setClientId(SA_CLIENT_ID) + .setClientEmail(SA_CLIENT_EMAIL) + .setPrivateKey(privateKey) + .setPrivateKeyId(SA_PRIVATE_KEY_ID) + .setQuotaProjectId("my-quota-project-id") + .setDefaultAudience(URI.create("default-audience")) + .build(); + + final Map> plainMetadata = credentials.getRequestMetadata(); + final AtomicBoolean success = new AtomicBoolean(false); + credentials.getRequestMetadata( + null, + null, + new RequestMetadataCallback() { + @Override + public void onSuccess(Map> metadata) { + assertEquals(plainMetadata, metadata); + success.set(true); + } + + @Override + public void onFailure(Throwable exception) { + fail("Should not throw a failure."); + } + }); + + assertTrue("Should have run onSuccess() callback", success.get()); + } + private void verifyJwtAccess( Map> metadata, String expectedEmail, diff --git a/oauth2_http/javatests/com/google/auth/oauth2/UserCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/UserCredentialsTest.java index 92e8ebe73..288f4d773 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/UserCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/UserCredentialsTest.java @@ -40,6 +40,7 @@ import com.google.api.client.json.GenericJson; import com.google.api.client.util.Clock; +import com.google.auth.RequestMetadataCallback; import com.google.auth.TestUtils; import com.google.auth.http.AuthHttpConstants; import com.google.auth.oauth2.GoogleCredentialsTest.MockHttpTransportFactory; @@ -56,6 +57,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -622,6 +624,81 @@ public void saveAndRestoreUserCredential_saveAndRestored_throws() throws IOExcep assertEquals(userCredentials.getRefreshToken(), restoredCredentials.getRefreshToken()); } + @Test + public void getRequestMetadataSetsQuotaProjectId() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET); + transportFactory.transport.addRefreshToken(REFRESH_TOKEN, ACCESS_TOKEN); + + UserCredentials userCredentials = + UserCredentials.newBuilder() + .setClientId(CLIENT_ID) + .setClientSecret(CLIENT_SECRET) + .setRefreshToken(REFRESH_TOKEN) + .setQuotaProjectId("my-quota-project-id") + .setHttpTransportFactory(transportFactory) + .build(); + + Map> metadata = userCredentials.getRequestMetadata(); + assertTrue(metadata.containsKey("x-goog-user-project")); + List headerValues = metadata.get("x-goog-user-project"); + assertEquals(1, headerValues.size()); + assertEquals("my-quota-project-id", headerValues.get(0)); + } + + @Test + public void getRequestMetadataNoQuotaProjectId() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET); + transportFactory.transport.addRefreshToken(REFRESH_TOKEN, ACCESS_TOKEN); + + UserCredentials userCredentials = + UserCredentials.newBuilder() + .setClientId(CLIENT_ID) + .setClientSecret(CLIENT_SECRET) + .setRefreshToken(REFRESH_TOKEN) + .setHttpTransportFactory(transportFactory) + .build(); + + Map> metadata = userCredentials.getRequestMetadata(); + assertFalse(metadata.containsKey("x-goog-user-project")); + } + + @Test + public void getRequestMetadataWithCallback() throws IOException { + MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory(); + transportFactory.transport.addClient(CLIENT_ID, CLIENT_SECRET); + transportFactory.transport.addRefreshToken(REFRESH_TOKEN, ACCESS_TOKEN); + + UserCredentials userCredentials = + UserCredentials.newBuilder() + .setClientId(CLIENT_ID) + .setClientSecret(CLIENT_SECRET) + .setRefreshToken(REFRESH_TOKEN) + .setQuotaProjectId("my-quota-project-id") + .setHttpTransportFactory(transportFactory) + .build(); + final Map> plainMetadata = userCredentials.getRequestMetadata(); + final AtomicBoolean success = new AtomicBoolean(false); + userCredentials.getRequestMetadata( + null, + null, + new RequestMetadataCallback() { + @Override + public void onSuccess(Map> metadata) { + assertEquals(plainMetadata, metadata); + success.set(true); + } + + @Override + public void onFailure(Throwable exception) { + fail("Should not throw a failure."); + } + }); + + assertTrue("Should have run onSuccess() callback", success.get()); + } + static GenericJson writeUserJson( String clientId, String clientSecret, String refreshToken, String quotaProjectId) { GenericJson json = new GenericJson();