Skip to content

Commit

Permalink
add enable_reauth_refresh flag
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Jul 30, 2021
1 parent c32f77b commit 9d01a79
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 3 deletions.
1 change: 1 addition & 0 deletions google/oauth2/_credentials_async.py
Expand Up @@ -75,6 +75,7 @@ async def refresh(self, request):
self._client_secret,
scopes=self._scopes,
rapt_token=self._rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
)

self.token = access_token
Expand Down
8 changes: 8 additions & 0 deletions google/oauth2/_reauth_async.py
Expand Up @@ -250,6 +250,7 @@ async def refresh_grant(
client_secret,
scopes=None,
rapt_token=None,
enable_reauth_refresh=False,
):
"""Implements the reauthentication flow.
Expand All @@ -267,6 +268,8 @@ async def refresh_grant(
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The rapt token for reauth.
enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow should
be used. The default value is False.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The
Expand Down Expand Up @@ -301,6 +304,11 @@ async def refresh_grant(
== reauth._REAUTH_NEEDED_ERROR_RAPT_REQUIRED
)
):
if not enable_reauth_refresh:
raise exceptions.RefreshError(
"Reauthenticatio is needed. Please run `gcloud auth application-default login` to reauthentciate."
)

rapt_token = await get_rapt_token(
request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
)
Expand Down
7 changes: 7 additions & 0 deletions google/oauth2/credentials.py
Expand Up @@ -75,6 +75,7 @@ def __init__(
expiry=None,
rapt_token=None,
refresh_handler=None,
enable_reauth_refresh=False,
):
"""
Args:
Expand Down Expand Up @@ -111,6 +112,8 @@ def __init__(
refresh tokens are provided and tokens are obtained by calling
some external process on demand. It is particularly useful for
retrieving downscoped tokens from a token broker.
enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow should
be used. The default value is False.
"""
super(Credentials, self).__init__()
self.token = token
Expand All @@ -125,6 +128,7 @@ def __init__(
self._quota_project_id = quota_project_id
self._rapt_token = rapt_token
self.refresh_handler = refresh_handler
self._enable_reauth_refresh = enable_reauth_refresh

def __getstate__(self):
"""A __getstate__ method must exist for the __setstate__ to be called
Expand Down Expand Up @@ -153,6 +157,7 @@ def __setstate__(self, d):
self._client_secret = d.get("_client_secret")
self._quota_project_id = d.get("_quota_project_id")
self._rapt_token = d.get("_rapt_token")
self._enable_reauth_refresh = d.get("_enable_reauth_refresh")
# The refresh_handler setter should be used to repopulate this.
self._refresh_handler = None

Expand Down Expand Up @@ -243,6 +248,7 @@ def with_quota_project(self, quota_project_id):
default_scopes=self.default_scopes,
quota_project_id=quota_project_id,
rapt_token=self.rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
)

@_helpers.copy_docstring(credentials.Credentials)
Expand Down Expand Up @@ -298,6 +304,7 @@ def refresh(self, request):
self._client_secret,
scopes=scopes,
rapt_token=self._rapt_token,
enable_reauth_refresh=self._enable_reauth_refresh,
)

self.token = access_token
Expand Down
10 changes: 9 additions & 1 deletion google/oauth2/reauth.py
Expand Up @@ -277,6 +277,7 @@ def refresh_grant(
client_secret,
scopes=None,
rapt_token=None,
enable_reauth_refresh=False,
):
"""Implements the reauthentication flow.
Expand All @@ -294,6 +295,8 @@ def refresh_grant(
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The rapt token for reauth.
enable_reauth_refresh (Optional[bool]): Whether reauth refresh flow should
be used. The default value is False.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str], str]: The
Expand All @@ -312,7 +315,7 @@ def refresh_grant(
}
if scopes:
body["scope"] = " ".join(scopes)
if rapt_token:
if rapt_token and enable_reauth_refresh:
body["rapt"] = rapt_token

response_status_ok, response_data = _client._token_endpoint_request_no_throw(
Expand All @@ -326,6 +329,11 @@ def refresh_grant(
or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED
)
):
if not enable_reauth_refresh:
raise exceptions.RefreshError(
"Reauthenticatio is needed. Please run `gcloud auth application-default login` to reauthentciate."
)

rapt_token = get_rapt_token(
request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
)
Expand Down
11 changes: 11 additions & 0 deletions tests/oauth2/test_credentials.py
Expand Up @@ -51,6 +51,7 @@ def make_credentials(cls):
client_id=cls.CLIENT_ID,
client_secret=cls.CLIENT_SECRET,
rapt_token=cls.RAPT_TOKEN,
enable_reauth_refresh=True,
)

def test_default_state(self):
Expand Down Expand Up @@ -149,6 +150,7 @@ def test_refresh_success(self, unused_utcnow, refresh_grant):
self.CLIENT_SECRET,
None,
self.RAPT_TOKEN,
True,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -219,6 +221,7 @@ def test_refresh_with_refresh_token_and_refresh_handler(
self.CLIENT_SECRET,
None,
self.RAPT_TOKEN,
False,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -422,6 +425,7 @@ def test_credentials_with_scopes_requested_refresh_success(
scopes=scopes,
default_scopes=default_scopes,
rapt_token=self.RAPT_TOKEN,
enable_reauth_refresh=True,
)

# Refresh credentials
Expand All @@ -436,6 +440,7 @@ def test_credentials_with_scopes_requested_refresh_success(
self.CLIENT_SECRET,
scopes,
self.RAPT_TOKEN,
True,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -484,6 +489,7 @@ def test_credentials_with_only_default_scopes_requested(
client_secret=self.CLIENT_SECRET,
default_scopes=default_scopes,
rapt_token=self.RAPT_TOKEN,
enable_reauth_refresh=True,
)

# Refresh credentials
Expand All @@ -498,6 +504,7 @@ def test_credentials_with_only_default_scopes_requested(
self.CLIENT_SECRET,
default_scopes,
self.RAPT_TOKEN,
True,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -549,6 +556,7 @@ def test_credentials_with_scopes_returned_refresh_success(
client_secret=self.CLIENT_SECRET,
scopes=scopes,
rapt_token=self.RAPT_TOKEN,
enable_reauth_refresh=True,
)

# Refresh credentials
Expand All @@ -563,6 +571,7 @@ def test_credentials_with_scopes_returned_refresh_success(
self.CLIENT_SECRET,
scopes,
self.RAPT_TOKEN,
True,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -615,6 +624,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error(
client_secret=self.CLIENT_SECRET,
scopes=scopes,
rapt_token=self.RAPT_TOKEN,
enable_reauth_refresh=True,
)

# Refresh credentials
Expand All @@ -632,6 +642,7 @@ def test_credentials_with_scopes_refresh_failure_raises_refresh_error(
self.CLIENT_SECRET,
scopes,
self.RAPT_TOKEN,
True,
)

# Check that the credentials have the token and expiry
Expand Down
23 changes: 22 additions & 1 deletion tests/oauth2/test_reauth.py
Expand Up @@ -270,6 +270,7 @@ def test_refresh_grant_failed():
"client_secret",
scopes=["foo", "bar"],
rapt_token="rapt_token",
enable_reauth_refresh=True,
)
assert excinfo.match(r"Bad request")
mock_token_request.assert_called_with(
Expand Down Expand Up @@ -298,11 +299,31 @@ def test_refresh_grant_success():
"google.oauth2.reauth.get_rapt_token", return_value="new_rapt_token"
):
assert reauth.refresh_grant(
MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
MOCK_REQUEST,
"token_uri",
"refresh_token",
"client_id",
"client_secret",
enable_reauth_refresh=True,
) == (
"access_token",
"refresh_token",
None,
{"access_token": "access_token"},
"new_rapt_token",
)


def test_refresh_grant_reauth_refresh_disabled():
with mock.patch(
"google.oauth2._client._token_endpoint_request_no_throw"
) as mock_token_request:
mock_token_request.side_effect = [
(False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
(True, {"access_token": "access_token"}),
]
with pytest.raises(exceptions.RefreshError) as excinfo:
reauth.refresh_grant(
MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
)
assert excinfo.match(r"Reauthenticatio is needed")
5 changes: 5 additions & 0 deletions tests_async/oauth2/test_credentials_async.py
Expand Up @@ -43,6 +43,7 @@ def make_credentials(cls):
token_uri=cls.TOKEN_URI,
client_id=cls.CLIENT_ID,
client_secret=cls.CLIENT_SECRET,
enable_reauth_refresh=True,
)

def test_default_state(self):
Expand Down Expand Up @@ -97,6 +98,7 @@ async def test_refresh_success(self, unused_utcnow, refresh_grant):
self.CLIENT_SECRET,
None,
None,
True,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -169,6 +171,7 @@ async def test_credentials_with_scopes_requested_refresh_success(
self.CLIENT_SECRET,
scopes,
"old_rapt_token",
False,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -231,6 +234,7 @@ async def test_credentials_with_scopes_returned_refresh_success(
self.CLIENT_SECRET,
scopes,
None,
False,
)

# Check that the credentials have the token and expiry
Expand Down Expand Up @@ -301,6 +305,7 @@ async def test_credentials_with_scopes_refresh_failure_raises_refresh_error(
self.CLIENT_SECRET,
scopes,
None,
False,
)

# Check that the credentials have the token and expiry
Expand Down
23 changes: 22 additions & 1 deletion tests_async/oauth2/test_reauth_async.py
Expand Up @@ -318,11 +318,32 @@ async def test_refresh_grant_success():
"google.oauth2._reauth_async.get_rapt_token", return_value="new_rapt_token"
):
assert await _reauth_async.refresh_grant(
MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
MOCK_REQUEST,
"token_uri",
"refresh_token",
"client_id",
"client_secret",
enable_reauth_refresh=True,
) == (
"access_token",
"refresh_token",
None,
{"access_token": "access_token"},
"new_rapt_token",
)


@pytest.mark.asyncio
async def test_refresh_grant_reauth_refresh_disabled():
with mock.patch(
"google.oauth2._client_async._token_endpoint_request_no_throw"
) as mock_token_request:
mock_token_request.side_effect = [
(False, {"error": "invalid_grant", "error_subtype": "rapt_required"}),
(True, {"access_token": "access_token"}),
]
with pytest.raises(exceptions.RefreshError) as excinfo:
assert await _reauth_async.refresh_grant(
MOCK_REQUEST, "token_uri", "refresh_token", "client_id", "client_secret"
)
assert excinfo.match(r"Reauthenticatio is needed")

0 comments on commit 9d01a79

Please sign in to comment.