Skip to content

Commit

Permalink
feat: add reauth support to async user credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 committed Apr 22, 2021
1 parent 36e6f0f commit 3a44dff
Show file tree
Hide file tree
Showing 7 changed files with 785 additions and 110 deletions.
125 changes: 62 additions & 63 deletions google/oauth2/_client_async.py
Expand Up @@ -30,70 +30,40 @@
from six.moves import http_client
from six.moves import urllib

from google.auth import _helpers
from google.auth import exceptions
from google.auth import jwt
from google.oauth2 import _client as client


def _handle_error_response(response_body):
""""Translates an error response into an exception.
Args:
response_body (str): The decoded response data.
Raises:
google.auth.exceptions.RefreshError
"""
try:
error_data = json.loads(response_body)
error_details = "{}: {}".format(
error_data["error"], error_data.get("error_description")
)
# If no details could be extracted, use the response data.
except (KeyError, ValueError):
error_details = response_body

raise exceptions.RefreshError(error_details, response_body)


def _parse_expiry(response_data):
"""Parses the expiry field from a response into a datetime.
Args:
response_data (Mapping): The JSON-parsed response data.
Returns:
Optional[datetime]: The expiration or ``None`` if no expiration was
specified.
"""
expires_in = response_data.get("expires_in", None)

if expires_in is not None:
return _helpers.utcnow() + datetime.timedelta(seconds=expires_in)
else:
return None


async def _token_endpoint_request(request, token_uri, body):
async def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
successful, and a mapping for the JSON-decoded response data.
"""
body = urllib.parse.urlencode(body).encode("utf-8")
headers = {"content-type": client._URLENCODED_CONTENT_TYPE}
if use_json:
headers = {"Content-Type": client._JSON_CONTENT_TYPE}
body = json.dumps(body).encode("utf-8")
else:
headers = {"Content-Type": client._URLENCODED_CONTENT_TYPE}
body = urllib.parse.urlencode(body).encode("utf-8")

if access_token:
headers["Authorization"] = "Bearer {}".format(access_token)

retry = 0
# retry to fetch token for maximum of two times if any internal failure
Expand Down Expand Up @@ -126,8 +96,38 @@ async def _token_endpoint_request(request, token_uri, body):
):
retry += 1
continue
_handle_error_response(response_body)
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data


async def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
response_status_ok, response_data = await _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
)
if not response_status_ok:
client._handle_error_response(response_data)
return response_data


Expand Down Expand Up @@ -163,7 +163,7 @@ async def jwt_grant(request, token_uri, assertion):
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

expiry = _parse_expiry(response_data)
expiry = client._parse_expiry(response_data)

return access_token, expiry, response_data

Expand Down Expand Up @@ -210,7 +210,13 @@ async def id_token_jwt_grant(request, token_uri, assertion):


async def refresh_grant(
request, token_uri, refresh_token, client_id, client_secret, scopes=None
request,
token_uri,
refresh_token,
client_id,
client_secret,
scopes=None,
rapt_token=None,
):
"""Implements the OAuth 2.0 refresh token grant.
Expand All @@ -229,10 +235,11 @@ async def refresh_grant(
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The reauth Proof Token.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
access token, new refresh token, expiration, and additional data
access token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
Expand All @@ -249,16 +256,8 @@ async def refresh_grant(
}
if scopes:
body["scope"] = " ".join(scopes)
if rapt_token:
body["rapt"] = rapt_token

response_data = await _token_endpoint_request(request, token_uri, body)

try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

refresh_token = response_data.get("refresh_token", refresh_token)
expiry = _parse_expiry(response_data)

return access_token, refresh_token, expiry, response_data
return client._handle_refresh_grant_response(response_data, refresh_token)
13 changes: 8 additions & 5 deletions google/oauth2/_credentials_async.py
Expand Up @@ -34,7 +34,7 @@
from google.auth import _credentials_async as credentials
from google.auth import _helpers
from google.auth import exceptions
from google.oauth2 import _client_async as _client
from google.oauth2 import _reauth_async as reauth
from google.oauth2 import credentials as oauth2_credentials


Expand Down Expand Up @@ -66,23 +66,26 @@ async def refresh(self, request):
refresh_token,
expiry,
grant_response,
) = await _client.refresh_grant(
rapt_token,
) = await reauth.refresh_grant(
request,
self._token_uri,
self._refresh_token,
self._client_id,
self._client_secret,
self._scopes,
scopes=self._scopes,
rapt_token=self._rapt_token,
)

self.token = access_token
self.expiry = expiry
self._refresh_token = refresh_token
self._id_token = grant_response.get("id_token")
self._rapt_token = rapt_token

if self._scopes and "scopes" in grant_response:
if self._scopes and "scope" in grant_response:
requested_scopes = frozenset(self._scopes)
granted_scopes = frozenset(grant_response["scopes"].split())
granted_scopes = frozenset(grant_response["scope"].split())
scopes_requested_but_not_granted = requested_scopes - granted_scopes
if scopes_requested_but_not_granted:
raise exceptions.RefreshError(
Expand Down

0 comments on commit 3a44dff

Please sign in to comment.