diff --git a/google/auth/exceptions.py b/google/auth/exceptions.py index 57f181ea1..e9e737780 100644 --- a/google/auth/exceptions.py +++ b/google/auth/exceptions.py @@ -57,3 +57,7 @@ def __init__(self, message=None): super(ReauthFailError, self).__init__( "Reauthentication failed. {0}".format(message) ) + + +class ReauthSamlChallengeFailError(ReauthFailError): + """An exception for SAML reauth challenge failures.""" diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py index b4878c543..e7b9637c8 100644 --- a/google/oauth2/_credentials_async.py +++ b/google/oauth2/_credentials_async.py @@ -75,6 +75,7 @@ async def refresh(self, request): self._client_secret, scopes=self._scopes, rapt_token=self._rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, ) self.token = access_token diff --git a/google/oauth2/_reauth_async.py b/google/oauth2/_reauth_async.py index 510578bf7..f74f50b43 100644 --- a/google/oauth2/_reauth_async.py +++ b/google/oauth2/_reauth_async.py @@ -248,6 +248,7 @@ async def refresh_grant( client_secret, scopes=None, rapt_token=None, + enable_reauth_refresh=False, ): """Implements the reauthentication flow. @@ -265,6 +266,9 @@ async def refresh_grant( token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). rapt_token (Optional(str)): The rapt token for reauth. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow + should be used. The default value is False. This option is for + gcloud only, other users should use the default value. Returns: Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The @@ -299,6 +303,11 @@ async def refresh_grant( == reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED ) ): + if not enable_reauth_refresh: + raise exceptions.RefreshError( + "Reauthentication is needed. Please run `gcloud auth login --update-adc` to reauthenticate." + ) + rapt_token = await get_rapt_token( request, client_id, client_secret, refresh_token, token_uri, scopes=scopes ) diff --git a/google/oauth2/challenges.py b/google/oauth2/challenges.py index 7756a8057..0baff62e0 100644 --- a/google/oauth2/challenges.py +++ b/google/oauth2/challenges.py @@ -25,6 +25,9 @@ REAUTH_ORIGIN = "https://accounts.google.com" +SAML_CHALLENGE_MESSAGE = ( + "Please run `gcloud auth login` to complete reauthentication with SAML." +) def get_user_password(text): @@ -148,7 +151,30 @@ def obtain_challenge_input(self, metadata): return None +class SamlChallenge(ReauthChallenge): + """Challenge that asks the users to browse to their ID Providers. + + Currently SAML challenge is not supported. When obtaining the challenge + input, exception will be raised to instruct the users to run + `gcloud auth login` for reauthentication. + """ + + @property + def name(self): + return "SAML" + + @property + def is_locally_eligible(self): + return True + + def obtain_challenge_input(self, metadata): + # Magic Arch has not fully supported returning a proper dedirect URL + # for programmatic SAML users today. So we error our here and request + # users to use gcloud to complete a login. + raise exceptions.ReauthSamlChallengeFailError(SAML_CHALLENGE_MESSAGE) + + AVAILABLE_CHALLENGES = { challenge.name: challenge - for challenge in [SecurityKeyChallenge(), PasswordChallenge()] + for challenge in [SecurityKeyChallenge(), PasswordChallenge(), SamlChallenge()] } diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index 98fd71b04..e259f7825 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -54,6 +54,9 @@ class Credentials(credentials.ReadOnlyScoped, credentials.CredentialsWithQuotaPr credentials = credentials.with_quota_project('myproject-123) + Reauth is disabled by default. To enable reauth, set the + `enable_reauth_refresh` parameter to True in the constructor. Note that + reauth feature is intended for gcloud to use only. If reauth is enabled, `pyu2f` dependency has to be installed in order to use security key reauth feature. Dependency can be installed via `pip install pyu2f` or `pip install google-auth[reauth]`. @@ -73,6 +76,7 @@ def __init__( expiry=None, rapt_token=None, refresh_handler=None, + enable_reauth_refresh=False, ): """ Args: @@ -109,6 +113,8 @@ def __init__( refresh tokens are provided and tokens are obtained by calling some external process on demand. It is particularly useful for retrieving downscoped tokens from a token broker. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow + should be used. This flag is for gcloud to use only. """ super(Credentials, self).__init__() self.token = token @@ -123,6 +129,7 @@ def __init__( self._quota_project_id = quota_project_id self._rapt_token = rapt_token self.refresh_handler = refresh_handler + self._enable_reauth_refresh = enable_reauth_refresh def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called @@ -151,6 +158,7 @@ def __setstate__(self, d): self._client_secret = d.get("_client_secret") self._quota_project_id = d.get("_quota_project_id") self._rapt_token = d.get("_rapt_token") + self._enable_reauth_refresh = d.get("_enable_reauth_refresh") # The refresh_handler setter should be used to repopulate this. self._refresh_handler = None @@ -241,6 +249,7 @@ def with_quota_project(self, quota_project_id): default_scopes=self.default_scopes, quota_project_id=quota_project_id, rapt_token=self.rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, ) @_helpers.copy_docstring(credentials.Credentials) @@ -296,6 +305,7 @@ def refresh(self, request): self._client_secret, scopes=scopes, rapt_token=self._rapt_token, + enable_reauth_refresh=self._enable_reauth_refresh, ) self.token = access_token @@ -366,6 +376,7 @@ def from_authorized_user_info(cls, info, scopes=None): client_secret=info.get("client_secret"), quota_project_id=info.get("quota_project_id"), # may not exist expiry=expiry, + rapt_token=info.get("rapt_token"), # may not exist ) @classmethod diff --git a/google/oauth2/reauth.py b/google/oauth2/reauth.py index fc2629e82..1e496d12e 100644 --- a/google/oauth2/reauth.py +++ b/google/oauth2/reauth.py @@ -275,6 +275,7 @@ def refresh_grant( client_secret, scopes=None, rapt_token=None, + enable_reauth_refresh=False, ): """Implements the reauthentication flow. @@ -292,6 +293,9 @@ def refresh_grant( token has a wild card scope (e.g. 'https://www.googleapis.com/auth/any-api'). rapt_token (Optional(str)): The rapt token for reauth. + enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow + should be used. The default value is False. This option is for + gcloud only, other users should use the default value. Returns: Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The @@ -324,6 +328,11 @@ def refresh_grant( or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED ) ): + if not enable_reauth_refresh: + raise exceptions.RefreshError( + "Reauthentication is needed. Please run `gcloud auth login --update-adc` to reauthenticate." + ) + rapt_token = get_rapt_token( request, client_id, client_secret, refresh_token, token_uri, scopes=scopes ) diff --git a/tests/data/authorized_user_with_rapt_token.json b/tests/data/authorized_user_with_rapt_token.json new file mode 100644 index 000000000..64b161d42 --- /dev/null +++ b/tests/data/authorized_user_with_rapt_token.json @@ -0,0 +1,8 @@ +{ + "client_id": "123", + "client_secret": "secret", + "refresh_token": "alabalaportocala", + "type": "authorized_user", + "rapt_token": "rapt" + } + \ No newline at end of file diff --git a/tests/oauth2/test_challenges.py b/tests/oauth2/test_challenges.py index 019b908da..412895ada 100644 --- a/tests/oauth2/test_challenges.py +++ b/tests/oauth2/test_challenges.py @@ -130,3 +130,11 @@ def test_password_challenge(getpass_mock): assert challenges.PasswordChallenge().obtain_challenge_input({}) == { "credential": " " } + + +def test_saml_challenge(): + challenge = challenges.SamlChallenge() + assert challenge.is_locally_eligible + assert challenge.name == "SAML" + with pytest.raises(exceptions.ReauthSamlChallengeFailError): + challenge.obtain_challenge_input(None) diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index 4a7f66e7f..b6a80e3d0 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -51,6 +51,7 @@ def make_credentials(cls): client_id=cls.CLIENT_ID, client_secret=cls.CLIENT_SECRET, rapt_token=cls.RAPT_TOKEN, + enable_reauth_refresh=True, ) def test_default_state(self): @@ -149,6 +150,7 @@ def test_refresh_success(self, unused_utcnow, refresh_grant): self.CLIENT_SECRET, None, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -219,6 +221,7 @@ def test_refresh_with_refresh_token_and_refresh_handler( self.CLIENT_SECRET, None, self.RAPT_TOKEN, + False, ) # Check that the credentials have the token and expiry @@ -422,6 +425,7 @@ def test_credentials_with_scopes_requested_refresh_success( scopes=scopes, default_scopes=default_scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -436,6 +440,7 @@ def test_credentials_with_scopes_requested_refresh_success( self.CLIENT_SECRET, scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -484,6 +489,7 @@ def test_credentials_with_only_default_scopes_requested( client_secret=self.CLIENT_SECRET, default_scopes=default_scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -498,6 +504,7 @@ def test_credentials_with_only_default_scopes_requested( self.CLIENT_SECRET, default_scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -549,6 +556,7 @@ def test_credentials_with_scopes_returned_refresh_success( client_secret=self.CLIENT_SECRET, scopes=scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -563,6 +571,7 @@ def test_credentials_with_scopes_returned_refresh_success( self.CLIENT_SECRET, scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -615,6 +624,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( client_secret=self.CLIENT_SECRET, scopes=scopes, rapt_token=self.RAPT_TOKEN, + enable_reauth_refresh=True, ) # Refresh credentials @@ -632,6 +642,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error( self.CLIENT_SECRET, scopes, self.RAPT_TOKEN, + True, ) # Check that the credentials have the token and expiry @@ -731,6 +742,7 @@ def test_from_authorized_user_file(self): assert creds.refresh_token == info["refresh_token"] assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT assert creds.scopes is None + assert creds.rapt_token is None scopes = ["email", "profile"] creds = credentials.Credentials.from_authorized_user_file( @@ -742,6 +754,18 @@ def test_from_authorized_user_file(self): assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT assert creds.scopes == scopes + def test_from_authorized_user_file_with_rapt_token(self): + info = AUTH_USER_INFO.copy() + file_path = os.path.join(DATA_DIR, "authorized_user_with_rapt_token.json") + + creds = credentials.Credentials.from_authorized_user_file(file_path) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes is None + assert creds.rapt_token == "rapt" + def test_to_json(self): info = AUTH_USER_INFO.copy() expiry = datetime.datetime(2020, 8, 14, 15, 54, 1) diff --git a/tests/oauth2/test_reauth.py b/tests/oauth2/test_reauth.py index e9ffa8a79..58d649d83 100644 --- a/tests/oauth2/test_reauth.py +++ b/tests/oauth2/test_reauth.py @@ -270,6 +270,7 @@ def test_refresh_grant_failed(): "client_secret", scopes=["foo", "bar"], rapt_token="rapt_token", + enable_reauth_refresh=True, ) assert excinfo.match(r"Bad request") mock_token_request.assert_called_with( @@ -298,7 +299,12 @@ def test_refresh_grant_success(): "google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token" ): assert reauth.refresh_grant( - MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + MOCK_REQUEST, + "token_uri", + "refresh_token", + "client_id", + "client_secret", + enable_reauth_refresh=True, ) == ( "access_token", "refresh_token", @@ -306,3 +312,18 @@ def test_refresh_grant_success(): {"access_token": "access_token"}, "new_rapt_token", ) + + +def test_refresh_grant_reauth_refresh_disabled(): + with mock.patch( + "google.oauth2._client._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.side_effect = [ + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), + (True, {"access_token": "access_token"}), + ] + with pytest.raises(exceptions.RefreshError) as excinfo: + reauth.refresh_grant( + MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + ) + assert excinfo.match(r"Reauthentication is needed") diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py index 99cf16f80..bc89392ad 100644 --- a/tests_async/oauth2/test_credentials_async.py +++ b/tests_async/oauth2/test_credentials_async.py @@ -43,6 +43,7 @@ def make_credentials(cls): token_uri=cls.TOKEN_URI, client_id=cls.CLIENT_ID, client_secret=cls.CLIENT_SECRET, + enable_reauth_refresh=True, ) def test_default_state(self): @@ -97,6 +98,7 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant): self.CLIENT_SECRET, None, None, + True, ) # Check that the credentials have the token and expiry @@ -169,6 +171,7 @@ async def test_credentials_with_scopes_requested_refresh_success( self.CLIENT_SECRET, scopes, "old_rapt_token", + False, ) # Check that the credentials have the token and expiry @@ -231,6 +234,7 @@ async def test_credentials_with_scopes_returned_refresh_success( self.CLIENT_SECRET, scopes, None, + False, ) # Check that the credentials have the token and expiry @@ -301,6 +305,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( self.CLIENT_SECRET, scopes, None, + False, ) # Check that the credentials have the token and expiry diff --git a/tests_async/oauth2/test_reauth_async.py b/tests_async/oauth2/test_reauth_async.py index f144d89f5..d982e13a1 100644 --- a/tests_async/oauth2/test_reauth_async.py +++ b/tests_async/oauth2/test_reauth_async.py @@ -318,7 +318,12 @@ async def test_refresh_grant_success(): "google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token" ): assert await _reauth_async.refresh_grant( - MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + MOCK_REQUEST, + "token_uri", + "refresh_token", + "client_id", + "client_secret", + enable_reauth_refresh=True, ) == ( "access_token", "refresh_token", @@ -326,3 +331,19 @@ async def test_refresh_grant_success(): {"access_token": "access_token"}, "new_rapt_token", ) + + +@pytest.mark.asyncio +async def test_refresh_grant_reauth_refresh_disabled(): + with mock.patch( + "google.oauth2._client_async._token_endpoint_request_no_throw" + ) as mock_token_request: + mock_token_request.side_effect = [ + (False, {"error": "invalid_grant", "error_subtype": "rapt_required"}), + (True, {"access_token": "access_token"}), + ] + with pytest.raises(exceptions.RefreshError) as excinfo: + assert await _reauth_async.refresh_grant( + MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret" + ) + assert excinfo.match(r"Reauthentication is needed")