diff --git a/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java b/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java index f96cf096a..4f00517bc 100644 --- a/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java +++ b/oauth2_http/java/com/google/auth/oauth2/AwsCredentials.java @@ -243,13 +243,19 @@ private String buildSubjectToken(AwsRequestSignature signature) return URLEncoder.encode(token.toString(), "UTF-8"); } - private String getAwsRegion() throws IOException { + @VisibleForTesting + String getAwsRegion() throws IOException { // For AWS Lambda, the region is retrieved through the AWS_REGION environment variable. String region = getEnvironmentProvider().getEnv("AWS_REGION"); if (region != null) { return region; } + String defaultRegion = getEnvironmentProvider().getEnv("AWS_DEFAULT_REGION"); + if (defaultRegion != null) { + return defaultRegion; + } + if (awsCredentialSource.regionUrl == null || awsCredentialSource.regionUrl.isEmpty()) { throw new IOException( "Unable to determine the AWS region. The credential source does not contain the region URL."); diff --git a/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java b/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java index 7537c3098..1721fc5c1 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/AwsCredentialsTest.java @@ -345,6 +345,73 @@ public void getAwsSecurityCredentials_fromMetadataServer_noUrlProvided() { } } + @Test + public void getAwsRegion_awsRegionEnvironmentVariable() throws IOException { + TestEnvironmentProvider environmentProvider = new TestEnvironmentProvider(); + environmentProvider.setEnv("AWS_REGION", "region"); + environmentProvider.setEnv("AWS_DEFAULT_REGION", "defaultRegion"); + + MockExternalAccountCredentialsTransportFactory transportFactory = + new MockExternalAccountCredentialsTransportFactory(); + AwsCredentials awsCredentials = + (AwsCredentials) + AwsCredentials.newBuilder(AWS_CREDENTIAL) + .setHttpTransportFactory(transportFactory) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) + .setEnvironmentProvider(environmentProvider) + .build(); + + String region = awsCredentials.getAwsRegion(); + + // Should attempt to retrieve the region from AWS_REGION env var first. + // Metadata server would return us-east-1b. + assertEquals("region", region); + } + + @Test + public void getAwsRegion_awsDefaultRegionEnvironmentVariable() throws IOException { + TestEnvironmentProvider environmentProvider = new TestEnvironmentProvider(); + environmentProvider.setEnv("AWS_DEFAULT_REGION", "defaultRegion"); + + MockExternalAccountCredentialsTransportFactory transportFactory = + new MockExternalAccountCredentialsTransportFactory(); + AwsCredentials awsCredentials = + (AwsCredentials) + AwsCredentials.newBuilder(AWS_CREDENTIAL) + .setHttpTransportFactory(transportFactory) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) + .setEnvironmentProvider(environmentProvider) + .build(); + + String region = awsCredentials.getAwsRegion(); + + // Should attempt to retrieve the region from DEFAULT_AWS_REGION before calling the metadata + // server. Metadata server would return us-east-1b. + assertEquals("defaultRegion", region); + } + + @Test + public void getAwsRegion_metadataServer() throws IOException { + MockExternalAccountCredentialsTransportFactory transportFactory = + new MockExternalAccountCredentialsTransportFactory(); + AwsCredentials awsCredentials = + (AwsCredentials) + AwsCredentials.newBuilder(AWS_CREDENTIAL) + .setHttpTransportFactory(transportFactory) + .setCredentialSource(buildAwsCredentialSource(transportFactory)) + .build(); + + String region = awsCredentials.getAwsRegion(); + + // Should retrieve the region from the Metadata server. + String expectedRegion = + transportFactory + .transport + .getAwsRegion() + .substring(0, transportFactory.transport.getAwsRegion().length() - 1); + assertEquals(expectedRegion, region); + } + @Test public void createdScoped_clonedCredentialWithAddedScopes() { AwsCredentials credentials = diff --git a/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java b/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java index fc7e0cdb9..49e2b88be 100644 --- a/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java +++ b/oauth2_http/javatests/com/google/auth/oauth2/MockExternalAccountCredentialsTransport.java @@ -74,6 +74,7 @@ public class MockExternalAccountCredentialsTransport extends MockHttpTransport { private static final String TOKEN_TYPE = "Bearer"; private static final String ACCESS_TOKEN = "accessToken"; private static final String SERVICE_ACCOUNT_ACCESS_TOKEN = "serviceAccountAccessToken"; + private static final String AWS_REGION = "us-east-1b"; private static final Long EXPIRES_IN = 3600L; private static final JsonFactory JSON_FACTORY = new GsonFactory(); @@ -120,7 +121,7 @@ public LowLevelHttpResponse execute() throws IOException { if (AWS_REGION_URL.equals(url)) { return new MockLowLevelHttpResponse() .setContentType("text/html") - .setContent("us-east-1b"); + .setContent(AWS_REGION); } if (AWS_CREDENTIALS_URL.equals(url)) { return new MockLowLevelHttpResponse() @@ -245,6 +246,10 @@ public String getAwsRegionUrl() { return AWS_REGION_URL; } + public String getAwsRegion() { + return AWS_REGION; + } + public String getStsUrl() { return STS_URL; }