Skip to content

Commit

Permalink
feat: allow the AWS_DEFAULT_REGION environment variable (#721)
Browse files Browse the repository at this point in the history
Amazon has this variable documented, and apparently people are trying to 
use it, so we should support it
  • Loading branch information
ScruffyProdigy committed Mar 13, 2021
1 parent d80c85f commit 199da47
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
14 changes: 9 additions & 5 deletions google/auth/aws.py
Expand Up @@ -424,9 +424,9 @@ def retrieve_subject_token(self, request):
The logic is summarized as:
Retrieve the AWS region from the AWS_REGION environment variable or from
the AWS metadata server availability-zone if not found in the
environment variable.
Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION
environment variable or from the AWS metadata server availability-zone
if not found in the environment variable.
Check AWS credentials in environment variables. If not found, retrieve
from the AWS metadata server security-credentials endpoint.
Expand Down Expand Up @@ -504,8 +504,8 @@ def retrieve_subject_token(self, request):
)

def _get_region(self, request, url):
"""Retrieves the current AWS region from either the AWS_REGION
environment variable or from the AWS metadata server.
"""Retrieves the current AWS region from either the AWS_REGION or
AWS_DEFAULT_REGION environment variable or from the AWS metadata server.
Args:
request (google.auth.transport.Request): A callable used to make
Expand All @@ -526,6 +526,10 @@ def _get_region(self, request, url):
if env_aws_region is not None:
return env_aws_region

env_aws_region = os.environ.get(environment_vars.AWS_DEFAULT_REGION)
if env_aws_region is not None:
return env_aws_region

if not self._region_url:
raise exceptions.RefreshError("Unable to determine AWS region")
response = request(url=self._region_url, method="GET")
Expand Down
1 change: 1 addition & 0 deletions google/auth/environment_vars.py
Expand Up @@ -69,3 +69,4 @@
AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"
AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN"
AWS_REGION = "AWS_REGION"
AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION"
50 changes: 50 additions & 0 deletions tests/test_aws.py
Expand Up @@ -1043,6 +1043,56 @@ def test_retrieve_subject_token_success_environment_vars(self, utcnow, monkeypat
}
)

@mock.patch("google.auth._helpers.utcnow")
def test_retrieve_subject_token_success_environment_vars_with_default_region(
self, utcnow, monkeypatch
):
monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID)
monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY)
monkeypatch.setenv(environment_vars.AWS_SESSION_TOKEN, TOKEN)
monkeypatch.setenv(environment_vars.AWS_DEFAULT_REGION, self.AWS_REGION)
utcnow.return_value = datetime.datetime.strptime(
self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
)
credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)

subject_token = credentials.retrieve_subject_token(None)

assert subject_token == self.make_serialized_aws_signed_request(
{
"access_key_id": ACCESS_KEY_ID,
"secret_access_key": SECRET_ACCESS_KEY,
"security_token": TOKEN,
}
)

@mock.patch("google.auth._helpers.utcnow")
def test_retrieve_subject_token_success_environment_vars_with_both_regions_set(
self, utcnow, monkeypatch
):
monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID)
monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY)
monkeypatch.setenv(environment_vars.AWS_SESSION_TOKEN, TOKEN)
monkeypatch.setenv(environment_vars.AWS_DEFAULT_REGION, "Malformed AWS Region")
# This test makes sure that the AWS_REGION gets used over AWS_DEFAULT_REGION,
# So, AWS_DEFAULT_REGION is set to something that would cause the test to fail,
# And AWS_REGION is set to the a valid value, and it should succeed
monkeypatch.setenv(environment_vars.AWS_REGION, self.AWS_REGION)
utcnow.return_value = datetime.datetime.strptime(
self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ"
)
credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE)

subject_token = credentials.retrieve_subject_token(None)

assert subject_token == self.make_serialized_aws_signed_request(
{
"access_key_id": ACCESS_KEY_ID,
"secret_access_key": SECRET_ACCESS_KEY,
"security_token": TOKEN,
}
)

@mock.patch("google.auth._helpers.utcnow")
def test_retrieve_subject_token_success_environment_vars_no_session_token(
self, utcnow, monkeypatch
Expand Down

0 comments on commit 199da47

Please sign in to comment.