Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Mar 30, 2021
1 parent 9e87642 commit d2a3660
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 79 deletions.
2 changes: 1 addition & 1 deletion google/auth/exceptions.py
Expand Up @@ -55,5 +55,5 @@ class ReauthFailError(RefreshError):

def __init__(self, message=None):
super(ReauthFailError, self).__init__(
"Reauthentication challenge failed. {0}".format(message)
"Reauthentication failed. {0}".format(message)
)
47 changes: 32 additions & 15 deletions google/oauth2/_client.py
Expand Up @@ -45,18 +45,14 @@


def _handle_error_response(response_data):
""" "Translates an error response into an exception.
"""Translates an error response into an exception.
Args:
response_data (Mapping): The decoded response data.
Raises:
google.auth.exceptions.RefreshError
google.auth.exceptions.RefreshError: The errors contained in response_data.
"""
if not isinstance(response_data, Mapping):
raise exceptions.RefreshError(
f"response_data is a mapping object: '{response_data}'"
)
try:
error_details = "{}: {}".format(
response_data["error"], response_data.get("error_description")
Expand Down Expand Up @@ -86,24 +82,25 @@ def _parse_expiry(response_data):
return None


def _token_endpoint_request_no_error_check(
def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
successful, and a mapping for the JSON-decoded response data.
"""
if use_json:
headers = {"Content-Type": "application/json"}
Expand Down Expand Up @@ -158,6 +155,9 @@ def _token_endpoint_request(
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Expand All @@ -166,7 +166,7 @@ def _token_endpoint_request(
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
response_status_ok, response_data = _token_endpoint_request_no_error_check(
response_status_ok, response_data = _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
)
if not response_status_ok:
Expand Down Expand Up @@ -253,6 +253,22 @@ def id_token_jwt_grant(request, token_uri, assertion):


def _handle_refresh_grant_response(response_data, refresh_token):
"""Extract tokens from refresh grant response.
Args:
response_data (Mapping[str, str]): Refresh grant response data.
refresh_token (str): Current refresh token.
Returns:
Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access token,
refresh token, expiration, and additional data returned by the token
endpoint. If response_data doesn't have refresh token, then the current
refresh token will be returned.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
Expand Down Expand Up @@ -291,10 +307,11 @@ def refresh_grant(
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The reauth Proof Token.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
access token, new refresh token, expiration, and additional data
Tuple[str, str, Optional[datetime], Mapping[str, str]]: The access
token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
Expand Down
7 changes: 5 additions & 2 deletions google/oauth2/challenges.py
@@ -1,4 +1,4 @@
# Copyright 2017 Google Inc. All rights reserved.
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

""" Challenges for reauthentication.
"""

import abc
import base64
import getpass
Expand Down Expand Up @@ -95,7 +98,7 @@ def obtain_challenge_input(self, metadata):
import pyu2f.errors
import pyu2f.model
except ImportError as e:
warnings.warn(
sys.stderr.write(
"pyu2f is missing. Please install pyu2f to use Security key reauth feature."
)
return None
Expand Down
5 changes: 2 additions & 3 deletions google/oauth2/credentials.py
Expand Up @@ -98,7 +98,7 @@ def __init__(
quota_project_id (Optional[str]): The project ID used for quota and billing.
This project may be different from the project used to
create the credentials.
rapt_token (Optional[str]): The rapt token.
rapt_token (Optional[str]): The reauth Proof Token.
"""
super(Credentials, self).__init__()
self.token = token
Expand Down Expand Up @@ -180,8 +180,7 @@ def requires_scopes(self):

@property
def rapt_token(self):
"""Optional[str]: The OAuth 2.0 authorization server's token endpoint
URI."""
"""Optional[str]: The reauth Proof Token."""
return self._rapt_token

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
Expand Down
125 changes: 69 additions & 56 deletions google/oauth2/reauth.py
@@ -1,4 +1,4 @@
# Copyright 2017 Google Inc. All rights reserved.
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -21,7 +21,7 @@
``https://www.googleapis.com/auth/accounts.reauth``.
This module provides a high-level function for executing the Reauth process,
:func:`refresh_access_token`, and lower-level helpers for doing the individual
:func:`refresh_grant`, and lower-level helpers for doing the individual
steps of the reauth process.
Those steps are:
Expand All @@ -32,17 +32,12 @@
3. Refreshing the access token using the returned rapt token.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import sys

from google.auth import exceptions
from google.oauth2 import challenges
from google.oauth2 import _client
from six.moves import http_client
from six.moves import range


Expand All @@ -64,8 +59,8 @@ def _get_challenges(
"""Does initial request to reauth API to get the challenges.
Args:
request (Callable): callable to run http requests. Accepts uri,
method, body and headers. Returns a tuple: (response, content)
request (google.auth.transport.Request): A callable used to make
HTTP requests.
supported_challenge_types (Sequence[str]): list of challenge names
supported by the manager.
access_token (str): Access token with reauth scopes.
Expand All @@ -89,8 +84,8 @@ def _send_challenge_result(
"""Attempt to refresh access token by sending next challenge result.
Args:
request (Callable): callable to run http requests. Accepts uri,
method, body and headers. Returns a tuple: (response, content)
request (google.auth.transport.Request): A callable used to make
HTTP requests.
session_id (str): session id returned by the initial reauth call.
challenge_id (str): challenge id returned by the initial reauth call.
client_input: dict with a challenge-specific client input. For example:
Expand Down Expand Up @@ -124,13 +119,15 @@ def _run_next_challenge(msg, request, access_token):
https://reauth.googleapis.com/v2/sessions:start or from sending the
previous challenge response to
https://reauth.googleapis.com/v2/sessions/id:continue)
request: callable to run http requests. Accepts uri, method, body
and headers. Returns a tuple: (response, content)
request (google.auth.transport.Request): A callable used to make
HTTP requests.
access_token: reauth access token
Returns: rapt token.
Returns:
dict: The response from the reauth API.
Raises:
google.auth.exceptions.ReauthError if reauth failed
google.auth.exceptions.ReauthError: if reauth failed.
"""
for challenge in msg["challenges"]:
if challenge["status"] != "READY":
Expand Down Expand Up @@ -167,18 +164,20 @@ def _obtain_rapt(request, access_token, requested_scopes, rounds_num=5):
"""Given an http request method and reauth access token, get rapt token.
Args:
request: callable to run http requests. Accepts uri, method, body
and headers. Returns a tuple: (response, content)
request (google.auth.transport.Request): A callable used to make
HTTP requests.
access_token: reauth access token
requested_scopes: scopes required by the client application
rounds_num: max number of attempts to get a rapt after the next
challenge, before failing the reauth. This defines total number of
challenges + number of additional retries if the chalenge input
wasn't accepted.
Returns: rapt token.
Returns:
str: The rapt token.
Raises:
google.auth.exceptions.ReauthError if reauth failed
google.auth.exceptions.ReauthError: if reauth failed
"""
msg = None

Expand Down Expand Up @@ -226,8 +225,8 @@ def get_rapt_token(
"""Given an http request method and refresh_token, get rapt token.
Args:
request: callable to run http requests. Accepts uri, method, body
and headers. Returns a tuple: (response, content)
request (google.auth.transport.Request): A callable used to make
HTTP requests.
client_id: client id to get access token for reauth scope.
client_secret: client secret for the client_id
refresh_token: refresh token to refresh access token
Expand All @@ -240,20 +239,24 @@ def get_rapt_token(
"""
sys.stderr.write("Reauthentication required.\n")

# Get access token for reauth.
access_token, _, _, _ = _client.refresh_grant(
request=request,
client_id=client_id,
client_secret=client_secret,
refresh_token=refresh_token,
token_uri=token_uri,
scopes=[_REAUTH_SCOPE],
)
try:
# Get access token for reauth.
access_token, _, _, _ = _client.refresh_grant(
request=request,
client_id=client_id,
client_secret=client_secret,
refresh_token=refresh_token,
token_uri=token_uri,
scopes=[_REAUTH_SCOPE],
)

# Get rapt token from reauth API.
rapt_token = _obtain_rapt(request, access_token, requested_scopes=scopes)
# Get rapt token from reauth API.
rapt_token = _obtain_rapt(request, access_token, requested_scopes=scopes)

return rapt_token
return rapt_token
except exceptions.RefreshError as e:
# TODO: convert refresh error to reauth error?
raise e


def refresh_grant(
Expand Down Expand Up @@ -305,28 +308,38 @@ def refresh_grant(
if rapt_token:
body["rapt"] = rapt_token

response_status_ok, response_data = _client._token_endpoint_request_no_error_check(
request, token_uri, body
)
if (
not response_status_ok
and response_data.get("error") == _REAUTH_NEEDED_ERROR
and (
response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_INVALID_RAPT
or response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_RAPT_REQUIRED
try:
response_status_ok, response_data = _client._token_endpoint_request_no_throw(
request, token_uri, body
)
):
rapt_token = get_rapt_token(
request, client_id, client_secret, refresh_token, token_uri, scopes=scopes
if (
not response_status_ok
and response_data.get("error") == _REAUTH_NEEDED_ERROR
and (
response_data.get("error_subtype") == _REAUTH_NEEDED_ERROR_INVALID_RAPT
or response_data.get("error_subtype")
== _REAUTH_NEEDED_ERROR_RAPT_REQUIRED
)
):
rapt_token = get_rapt_token(
request,
client_id,
client_secret,
refresh_token,
token_uri,
scopes=scopes,
)
body["rapt"] = rapt_token
(
response_status_ok,
response_data,
) = _client._token_endpoint_request_no_throw(request, token_uri, body)

if not response_status_ok:
_client._handle_error_response(response_data)
return _client._handle_refresh_grant_response(response_data, refresh_token) + (
rapt_token,
)
body["rapt"] = rapt_token
(
response_status_ok,
response_data,
) = _client._token_endpoint_request_no_error_check(request, token_uri, body)

if not response_status_ok:
_client._handle_error_response(response_data)
return _client._handle_refresh_grant_response(response_data, refresh_token) + (
rapt_token,
)
except exceptions.RefreshError as e:
# TODO: convert to reauth error
raise e
1 change: 1 addition & 0 deletions sample.py
Expand Up @@ -15,3 +15,4 @@
creds._scopes = scopes

creds.refresh(req)
creds.refresh(req)
4 changes: 2 additions & 2 deletions tests/oauth2/test__client.py
Expand Up @@ -57,12 +57,12 @@ def test__handle_error_response():


def test__handle_error_response_non_json():
response_data = "Help, I'm alive"
response_data = {"foo": "bar"}

with pytest.raises(exceptions.RefreshError) as excinfo:
_client._handle_error_response(response_data)

assert excinfo.match(r"Help, I\'m alive")
assert excinfo.match(r"{\"foo\": \"bar\"}")


@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min)
Expand Down

0 comments on commit d2a3660

Please sign in to comment.