Skip to content

Commit

Permalink
fix: quotaProjectId should be applied for cached `getRequestMetadata(…
Browse files Browse the repository at this point in the history
…URI, Executor, RequestMetadataCallback)` (#509)

* test: add failing test for request metadata with callback

* refactor: use protected getAdditionalHeaders to provide quotaProjectId and other headers

* docs: add javadoc for new protected getAdditionalHeaders()

* test: add more tests for quotaProjectId

* docs: fix nit javadoc return format
  • Loading branch information
chingor13 committed Dec 3, 2020
1 parent 3b50627 commit 0a8412f
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 12 deletions.
28 changes: 22 additions & 6 deletions oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java
Expand Up @@ -38,6 +38,7 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.io.IOException;
import java.io.ObjectInputStream;
Expand All @@ -56,6 +57,7 @@ public class OAuth2Credentials extends Credentials {

private static final long serialVersionUID = 4556936364828217687L;
private static final long MINIMUM_TOKEN_MILLISECONDS = 60000L * 5L;
private static final Map<String, List<String>> EMPTY_EXTRA_HEADERS = Collections.emptyMap();

// byte[] is serializable, so the lock variable can be final
private final Object lock = new byte[0];
Expand Down Expand Up @@ -89,7 +91,7 @@ protected OAuth2Credentials() {
*/
protected OAuth2Credentials(AccessToken accessToken) {
if (accessToken != null) {
useAccessToken(accessToken);
useAccessToken(accessToken, EMPTY_EXTRA_HEADERS);
}
}

Expand Down Expand Up @@ -154,7 +156,9 @@ public void refresh() throws IOException {
synchronized (lock) {
requestMetadata = null;
temporaryAccess = null;
useAccessToken(Preconditions.checkNotNull(refreshAccessToken(), "new access token"));
useAccessToken(
Preconditions.checkNotNull(refreshAccessToken(), "new access token"),
getAdditionalHeaders());
if (changeListeners != null) {
for (CredentialsChangedListener listener : changeListeners) {
listener.onChanged(this);
Expand All @@ -163,6 +167,15 @@ public void refresh() throws IOException {
}
}

/**
* Provide additional headers to return as request metadata.
*
* @return additional headers
*/
protected Map<String, List<String>> getAdditionalHeaders() {
return EMPTY_EXTRA_HEADERS;
}

/**
* Refresh these credentials only if they have expired or are expiring imminently.
*
Expand All @@ -177,12 +190,15 @@ public void refreshIfExpired() throws IOException {
}

// Must be called under lock
private void useAccessToken(AccessToken token) {
private void useAccessToken(AccessToken token, Map<String, List<String>> additionalHeaders) {
this.temporaryAccess = token;
this.requestMetadata =
Collections.singletonMap(
AuthHttpConstants.AUTHORIZATION,
Collections.singletonList(OAuth2Utils.BEARER_PREFIX + token.getTokenValue()));
ImmutableMap.<String, List<String>>builder()
.put(
AuthHttpConstants.AUTHORIZATION,
Collections.singletonList(OAuth2Utils.BEARER_PREFIX + token.getTokenValue()))
.putAll(additionalHeaders)
.build();
}

// Must be called under lock
Expand Down
Expand Up @@ -599,9 +599,12 @@ public JwtCredentials jwtWithClaims(JwtClaims newClaims) {
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
Map<String, List<String>> requestMetadata = super.getRequestMetadata(uri);
return addQuotaProjectIdToRequestMetadata(quotaProjectId, requestMetadata);
protected Map<String, List<String>> getAdditionalHeaders() {
Map<String, List<String>> headers = super.getAdditionalHeaders();
if (quotaProjectId != null) {
return addQuotaProjectIdToRequestMetadata(quotaProjectId, headers);
}
return headers;
}

@Override
Expand Down
9 changes: 6 additions & 3 deletions oauth2_http/java/com/google/auth/oauth2/UserCredentials.java
Expand Up @@ -277,9 +277,12 @@ public void save(String filePath) throws IOException {
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
Map<String, List<String>> requestMetadata = super.getRequestMetadata(uri);
return addQuotaProjectIdToRequestMetadata(quotaProjectId, requestMetadata);
protected Map<String, List<String>> getAdditionalHeaders() {
Map<String, List<String>> headers = super.getAdditionalHeaders();
if (quotaProjectId != null) {
return addQuotaProjectIdToRequestMetadata(quotaProjectId, headers);
}
return headers;
}

@Override
Expand Down
Expand Up @@ -49,6 +49,7 @@
import com.google.api.client.testing.http.MockLowLevelHttpResponse;
import com.google.api.client.util.Clock;
import com.google.api.client.util.Joiner;
import com.google.auth.RequestMetadataCallback;
import com.google.auth.TestUtils;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.GoogleCredentialsTest.MockHttpTransportFactory;
Expand All @@ -68,6 +69,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -1032,6 +1034,97 @@ public void fromStream_noPrivateKeyId_throws() throws IOException {
testFromStreamException(serviceAccountStream, "private_key_id");
}

@Test
public void getRequestMetadataSetsQuotaProjectId() throws IOException {
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
transportFactory.transport.addClient(CLIENT_ID, "unused-client-secret");
transportFactory.transport.addServiceAccount(CLIENT_EMAIL, ACCESS_TOKEN);

PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(PRIVATE_KEY_PKCS8);
GoogleCredentials credentials =
ServiceAccountCredentials.newBuilder()
.setClientId(CLIENT_ID)
.setClientEmail(CLIENT_EMAIL)
.setPrivateKey(privateKey)
.setPrivateKeyId(PRIVATE_KEY_ID)
.setScopes(SCOPES)
.setServiceAccountUser(USER)
.setProjectId(PROJECT_ID)
.setQuotaProjectId("my-quota-project-id")
.setHttpTransportFactory(transportFactory)
.build();

Map<String, List<String>> metadata = credentials.getRequestMetadata();
assertTrue(metadata.containsKey("x-goog-user-project"));
List<String> headerValues = metadata.get("x-goog-user-project");
assertEquals(1, headerValues.size());
assertEquals("my-quota-project-id", headerValues.get(0));
}

@Test
public void getRequestMetadataNoQuotaProjectId() throws IOException {
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
transportFactory.transport.addClient(CLIENT_ID, "unused-client-secret");
transportFactory.transport.addServiceAccount(CLIENT_EMAIL, ACCESS_TOKEN);

PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(PRIVATE_KEY_PKCS8);
GoogleCredentials credentials =
ServiceAccountCredentials.newBuilder()
.setClientId(CLIENT_ID)
.setClientEmail(CLIENT_EMAIL)
.setPrivateKey(privateKey)
.setPrivateKeyId(PRIVATE_KEY_ID)
.setScopes(SCOPES)
.setServiceAccountUser(USER)
.setProjectId(PROJECT_ID)
.setHttpTransportFactory(transportFactory)
.build();

Map<String, List<String>> metadata = credentials.getRequestMetadata();
assertFalse(metadata.containsKey("x-goog-user-project"));
}

@Test
public void getRequestMetadataWithCallback() throws IOException {
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
transportFactory.transport.addClient(CLIENT_ID, "unused-client-secret");
transportFactory.transport.addServiceAccount(CLIENT_EMAIL, ACCESS_TOKEN);

PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(PRIVATE_KEY_PKCS8);
GoogleCredentials credentials =
ServiceAccountCredentials.newBuilder()
.setClientId(CLIENT_ID)
.setClientEmail(CLIENT_EMAIL)
.setPrivateKey(privateKey)
.setPrivateKeyId(PRIVATE_KEY_ID)
.setScopes(SCOPES)
.setServiceAccountUser(USER)
.setProjectId(PROJECT_ID)
.setQuotaProjectId("my-quota-project-id")
.setHttpTransportFactory(transportFactory)
.build();

final Map<String, List<String>> plainMetadata = credentials.getRequestMetadata();
final AtomicBoolean success = new AtomicBoolean(false);
credentials.getRequestMetadata(
null,
null,
new RequestMetadataCallback() {
@Override
public void onSuccess(Map<String, List<String>> metadata) {
assertEquals(plainMetadata, metadata);
success.set(true);
}

@Override
public void onFailure(Throwable exception) {
fail("Should not throw a failure.");
}
});

assertTrue("Should have run onSuccess() callback", success.get());
}

static GenericJson writeServiceAccountJson(
String clientId,
String clientEmail,
Expand Down
Expand Up @@ -46,6 +46,7 @@
import com.google.api.client.json.webtoken.JsonWebSignature;
import com.google.api.client.util.Clock;
import com.google.auth.Credentials;
import com.google.auth.RequestMetadataCallback;
import com.google.auth.TestClock;
import com.google.auth.http.AuthHttpConstants;
import com.google.auth.oauth2.GoogleCredentialsTest.MockHttpTransportFactory;
Expand All @@ -61,6 +62,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand Down Expand Up @@ -758,6 +760,76 @@ public void jwtWithClaims_defaultAudience() throws IOException {
verifyJwtAccess(metadata, SA_CLIENT_EMAIL, URI.create("default-audience"), SA_PRIVATE_KEY_ID);
}

@Test
public void getRequestMetadataSetsQuotaProjectId() throws IOException {
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
ServiceAccountJwtAccessCredentials credentials =
ServiceAccountJwtAccessCredentials.newBuilder()
.setClientId(SA_CLIENT_ID)
.setClientEmail(SA_CLIENT_EMAIL)
.setPrivateKey(privateKey)
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
.setQuotaProjectId("my-quota-project-id")
.setDefaultAudience(URI.create("default-audience"))
.build();

Map<String, List<String>> metadata = credentials.getRequestMetadata();
assertTrue(metadata.containsKey("x-goog-user-project"));
List<String> headerValues = metadata.get("x-goog-user-project");
assertEquals(1, headerValues.size());
assertEquals("my-quota-project-id", headerValues.get(0));
}

@Test
public void getRequestMetadataNoQuotaProjectId() throws IOException {
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
ServiceAccountJwtAccessCredentials credentials =
ServiceAccountJwtAccessCredentials.newBuilder()
.setClientId(SA_CLIENT_ID)
.setClientEmail(SA_CLIENT_EMAIL)
.setPrivateKey(privateKey)
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
.setDefaultAudience(URI.create("default-audience"))
.build();

Map<String, List<String>> metadata = credentials.getRequestMetadata();
assertFalse(metadata.containsKey("x-goog-user-project"));
}

@Test
public void getRequestMetadataWithCallback() throws IOException {
PrivateKey privateKey = ServiceAccountCredentials.privateKeyFromPkcs8(SA_PRIVATE_KEY_PKCS8);
ServiceAccountJwtAccessCredentials credentials =
ServiceAccountJwtAccessCredentials.newBuilder()
.setClientId(SA_CLIENT_ID)
.setClientEmail(SA_CLIENT_EMAIL)
.setPrivateKey(privateKey)
.setPrivateKeyId(SA_PRIVATE_KEY_ID)
.setQuotaProjectId("my-quota-project-id")
.setDefaultAudience(URI.create("default-audience"))
.build();

final Map<String, List<String>> plainMetadata = credentials.getRequestMetadata();
final AtomicBoolean success = new AtomicBoolean(false);
credentials.getRequestMetadata(
null,
null,
new RequestMetadataCallback() {
@Override
public void onSuccess(Map<String, List<String>> metadata) {
assertEquals(plainMetadata, metadata);
success.set(true);
}

@Override
public void onFailure(Throwable exception) {
fail("Should not throw a failure.");
}
});

assertTrue("Should have run onSuccess() callback", success.get());
}

private void verifyJwtAccess(
Map<String, List<String>> metadata,
String expectedEmail,
Expand Down

0 comments on commit 0a8412f

Please sign in to comment.