diff --git a/google/auth/aws.py b/google/auth/aws.py index b362dd315..c2b521c36 100644 --- a/google/auth/aws.py +++ b/google/auth/aws.py @@ -424,9 +424,9 @@ def retrieve_subject_token(self, request): The logic is summarized as: - Retrieve the AWS region from the AWS_REGION environment variable or from - the AWS metadata server availability-zone if not found in the - environment variable. + Retrieve the AWS region from the AWS_REGION or AWS_DEFAULT_REGION + environment variable or from the AWS metadata server availability-zone + if not found in the environment variable. Check AWS credentials in environment variables. If not found, retrieve from the AWS metadata server security-credentials endpoint. @@ -504,8 +504,8 @@ def retrieve_subject_token(self, request): ) def _get_region(self, request, url): - """Retrieves the current AWS region from either the AWS_REGION - environment variable or from the AWS metadata server. + """Retrieves the current AWS region from either the AWS_REGION or + AWS_DEFAULT_REGION environment variable or from the AWS metadata server. Args: request (google.auth.transport.Request): A callable used to make @@ -526,6 +526,10 @@ def _get_region(self, request, url): if env_aws_region is not None: return env_aws_region + env_aws_region = os.environ.get(environment_vars.AWS_DEFAULT_REGION) + if env_aws_region is not None: + return env_aws_region + if not self._region_url: raise exceptions.RefreshError("Unable to determine AWS region") response = request(url=self._region_url, method="GET") diff --git a/google/auth/environment_vars.py b/google/auth/environment_vars.py index 416bab0c0..f02774181 100644 --- a/google/auth/environment_vars.py +++ b/google/auth/environment_vars.py @@ -69,3 +69,4 @@ AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" AWS_SESSION_TOKEN = "AWS_SESSION_TOKEN" AWS_REGION = "AWS_REGION" +AWS_DEFAULT_REGION = "AWS_DEFAULT_REGION" diff --git a/tests/test_aws.py b/tests/test_aws.py index 7a55841ca..7c7ee36be 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -1043,6 +1043,56 @@ def test_retrieve_subject_token_success_environment_vars(self, utcnow, monkeypat } ) + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_environment_vars_with_default_region( + self, utcnow, monkeypatch + ): + monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID) + monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY) + monkeypatch.setenv(environment_vars.AWS_SESSION_TOKEN, TOKEN) + monkeypatch.setenv(environment_vars.AWS_DEFAULT_REGION, self.AWS_REGION) + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + + @mock.patch("google.auth._helpers.utcnow") + def test_retrieve_subject_token_success_environment_vars_with_both_regions_set( + self, utcnow, monkeypatch + ): + monkeypatch.setenv(environment_vars.AWS_ACCESS_KEY_ID, ACCESS_KEY_ID) + monkeypatch.setenv(environment_vars.AWS_SECRET_ACCESS_KEY, SECRET_ACCESS_KEY) + monkeypatch.setenv(environment_vars.AWS_SESSION_TOKEN, TOKEN) + monkeypatch.setenv(environment_vars.AWS_DEFAULT_REGION, "Malformed AWS Region") + # This test makes sure that the AWS_REGION gets used over AWS_DEFAULT_REGION, + # So, AWS_DEFAULT_REGION is set to something that would cause the test to fail, + # And AWS_REGION is set to the a valid value, and it should succeed + monkeypatch.setenv(environment_vars.AWS_REGION, self.AWS_REGION) + utcnow.return_value = datetime.datetime.strptime( + self.AWS_SIGNATURE_TIME, "%Y-%m-%dT%H:%M:%SZ" + ) + credentials = self.make_credentials(credential_source=self.CREDENTIAL_SOURCE) + + subject_token = credentials.retrieve_subject_token(None) + + assert subject_token == self.make_serialized_aws_signed_request( + { + "access_key_id": ACCESS_KEY_ID, + "secret_access_key": SECRET_ACCESS_KEY, + "security_token": TOKEN, + } + ) + @mock.patch("google.auth._helpers.utcnow") def test_retrieve_subject_token_success_environment_vars_no_session_token( self, utcnow, monkeypatch