Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add reauth support to async user credentials #738

Merged
merged 1 commit into from Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
125 changes: 62 additions & 63 deletions google/oauth2/_client_async.py
Expand Up @@ -30,70 +30,40 @@
from six.moves import http_client
from six.moves import urllib

from google.auth import _helpers
from google.auth import exceptions
from google.auth import jwt
from google.oauth2 import _client as client


def _handle_error_response(response_body):
""""Translates an error response into an exception.
Args:
response_body (str): The decoded response data.
Raises:
google.auth.exceptions.RefreshError
"""
try:
error_data = json.loads(response_body)
error_details = "{}: {}".format(
error_data["error"], error_data.get("error_description")
)
# If no details could be extracted, use the response data.
except (KeyError, ValueError):
error_details = response_body

raise exceptions.RefreshError(error_details, response_body)


def _parse_expiry(response_data):
"""Parses the expiry field from a response into a datetime.
Args:
response_data (Mapping): The JSON-parsed response data.
Returns:
Optional[datetime]: The expiration or ``None`` if no expiration was
specified.
"""
expires_in = response_data.get("expires_in", None)

if expires_in is not None:
return _helpers.utcnow() + datetime.timedelta(seconds=expires_in)
else:
return None


async def _token_endpoint_request(request, token_uri, body):
async def _token_endpoint_request_no_throw(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
This function doesn't throw on response errors.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
Tuple(bool, Mapping[str, str]): A boolean indicating if the request is
successful, and a mapping for the JSON-decoded response data.
"""
body = urllib.parse.urlencode(body).encode("utf-8")
headers = {"content-type": client._URLENCODED_CONTENT_TYPE}
if use_json:
headers = {"Content-Type": client._JSON_CONTENT_TYPE}
body = json.dumps(body).encode("utf-8")
else:
headers = {"Content-Type": client._URLENCODED_CONTENT_TYPE}
body = urllib.parse.urlencode(body).encode("utf-8")

if access_token:
headers["Authorization"] = "Bearer {}".format(access_token)

retry = 0
# retry to fetch token for maximum of two times if any internal failure
Expand Down Expand Up @@ -126,8 +96,38 @@ async def _token_endpoint_request(request, token_uri, body):
):
retry += 1
continue
_handle_error_response(response_body)
return response.status == http_client.OK, response_data

return response.status == http_client.OK, response_data


async def _token_endpoint_request(
request, token_uri, body, access_token=None, use_json=False
):
"""Makes a request to the OAuth 2.0 authorization server's token endpoint.
Args:
request (google.auth.transport.Request): A callable used to make
HTTP requests.
token_uri (str): The OAuth 2.0 authorizations server's token endpoint
URI.
body (Mapping[str, str]): The parameters to send in the request body.
access_token (Optional(str)): The access token needed to make the request.
use_json (Optional(bool)): Use urlencoded format or json format for the
content type. The default value is False.
Returns:
Mapping[str, str]: The JSON-decoded response data.
Raises:
google.auth.exceptions.RefreshError: If the token endpoint returned
an error.
"""
response_status_ok, response_data = await _token_endpoint_request_no_throw(
request, token_uri, body, access_token=access_token, use_json=use_json
)
if not response_status_ok:
client._handle_error_response(response_data)
return response_data


Expand Down Expand Up @@ -163,7 +163,7 @@ async def jwt_grant(request, token_uri, assertion):
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

expiry = _parse_expiry(response_data)
expiry = client._parse_expiry(response_data)

return access_token, expiry, response_data

Expand Down Expand Up @@ -210,7 +210,13 @@ async def id_token_jwt_grant(request, token_uri, assertion):


async def refresh_grant(
request, token_uri, refresh_token, client_id, client_secret, scopes=None
request,
token_uri,
refresh_token,
client_id,
client_secret,
scopes=None,
rapt_token=None,
):
"""Implements the OAuth 2.0 refresh token grant.
Expand All @@ -229,10 +235,11 @@ async def refresh_grant(
scopes must be authorized for the refresh token. Useful if refresh
token has a wild card scope (e.g.
'https://www.googleapis.com/auth/any-api').
rapt_token (Optional(str)): The reauth Proof Token.
Returns:
Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The
access token, new refresh token, expiration, and additional data
access token, new or current refresh token, expiration, and additional data
returned by the token endpoint.
Raises:
Expand All @@ -249,16 +256,8 @@ async def refresh_grant(
}
if scopes:
body["scope"] = " ".join(scopes)
if rapt_token:
body["rapt"] = rapt_token

response_data = await _token_endpoint_request(request, token_uri, body)

try:
access_token = response_data["access_token"]
except KeyError as caught_exc:
new_exc = exceptions.RefreshError("No access token in response.", response_data)
six.raise_from(new_exc, caught_exc)

refresh_token = response_data.get("refresh_token", refresh_token)
expiry = _parse_expiry(response_data)

return access_token, refresh_token, expiry, response_data
return client._handle_refresh_grant_response(response_data, refresh_token)
13 changes: 8 additions & 5 deletions google/oauth2/_credentials_async.py
Expand Up @@ -34,7 +34,7 @@
from google.auth import _credentials_async as credentials
from google.auth import _helpers
from google.auth import exceptions
from google.oauth2 import _client_async as _client
from google.oauth2 import _reauth_async as reauth
from google.oauth2 import credentials as oauth2_credentials


Expand Down Expand Up @@ -66,23 +66,26 @@ async def refresh(self, request):
refresh_token,
expiry,
grant_response,
) = await _client.refresh_grant(
rapt_token,
) = await reauth.refresh_grant(
request,
self._token_uri,
self._refresh_token,
self._client_id,
self._client_secret,
self._scopes,
scopes=self._scopes,
rapt_token=self._rapt_token,
)

self.token = access_token
self.expiry = expiry
self._refresh_token = refresh_token
self._id_token = grant_response.get("id_token")
self._rapt_token = rapt_token

if self._scopes and "scopes" in grant_response:
if self._scopes and "scope" in grant_response:
requested_scopes = frozenset(self._scopes)
granted_scopes = frozenset(grant_response["scopes"].split())
granted_scopes = frozenset(grant_response["scope"].split())
scopes_requested_but_not_granted = requested_scopes - granted_scopes
if scopes_requested_but_not_granted:
raise exceptions.RefreshError(
Expand Down