Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix expiry for to_json() #589

Merged
merged 7 commits into from Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 33 additions & 10 deletions google/oauth2/credentials.py
Expand Up @@ -31,6 +31,7 @@
.. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1
"""

from datetime import datetime
import io
import json

Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
client_secret=None,
scopes=None,
quota_project_id=None,
expiry=None,
):
"""
Args:
Expand Down Expand Up @@ -95,6 +97,7 @@ def __init__(
"""
super(Credentials, self).__init__()
self.token = token
self.expiry = expiry
self._refresh_token = refresh_token
self._id_token = id_token
self._scopes = scopes
Expand Down Expand Up @@ -128,6 +131,11 @@ def refresh_token(self):
"""Optional[str]: The OAuth 2.0 refresh token."""
return self._refresh_token

@property
def scopes(self):
"""Optional[str]: The OAuth 2.0 permission scopes."""
return self._scopes

@property
def token_uri(self):
"""Optional[str]: The OAuth 2.0 authorization server's token endpoint
Expand Down Expand Up @@ -241,16 +249,30 @@ def from_authorized_user_info(cls, info, scopes=None):
"fields {}.".format(", ".join(missing))
)

# access token expiry (datetime obj); auto-expire if not saved
expiry = info.get("expiry")
if expiry:
expiry = datetime.strptime(
expiry.rstrip("Z").split(".")[0], "%Y-%m-%dT%H:%M:%S"
)
else:
expiry = _helpers.utcnow() - _helpers.CLOCK_SKEW

# process scopes, which needs to be a seq
if scopes is None and "scopes" in info:
scopes = info.get("scopes")
if isinstance(scopes, str):
scopes = scopes.split(" ")

return cls(
None, # No access token, must be refreshed.
refresh_token=info["refresh_token"],
token_uri=_GOOGLE_OAUTH2_TOKEN_ENDPOINT,
token=info.get("token"),
refresh_token=info.get("refresh_token"),
token_uri=_GOOGLE_OAUTH2_TOKEN_ENDPOINT, # always overrides
scopes=scopes,
client_id=info["client_id"],
client_secret=info["client_secret"],
quota_project_id=info.get(
"quota_project_id"
), # quota project may not exist
client_id=info.get("client_id"),
client_secret=info.get("client_secret"),
quota_project_id=info.get("quota_project_id"), # may not exist
expiry=expiry,
)

@classmethod
Expand Down Expand Up @@ -294,8 +316,10 @@ def to_json(self, strip=None):
"client_secret": self.client_secret,
"scopes": self.scopes,
}
if self.expiry: # flatten expiry timestamp
prep["expiry"] = self.expiry.isoformat() + "Z"

# Remove empty entries
# Remove empty entries (those which are None)
prep = {k: v for k, v in prep.items() if v is not None}

# Remove entries that explicitely need to be removed
Expand All @@ -316,7 +340,6 @@ class UserAccessTokenCredentials(credentials.CredentialsWithQuotaProject):
specified, the current active account will be used.
quota_project_id (Optional[str]): The project ID used for quota
and billing.

"""

def __init__(self, account=None, quota_project_id=None):
Expand Down
24 changes: 24 additions & 0 deletions tests/oauth2/test_credentials.py
Expand Up @@ -359,6 +359,20 @@ def test_from_authorized_user_info(self):
assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT
assert creds.scopes == scopes

info["scopes"] = "email" # single non-array scope from file
creds = credentials.Credentials.from_authorized_user_info(info)
assert creds.scopes == [info["scopes"]]

info["scopes"] = ["email", "profile"] # array scope from file
creds = credentials.Credentials.from_authorized_user_info(info)
assert creds.scopes == info["scopes"]

expiry = datetime.datetime(2020, 8, 14, 15, 54, 1)
info["expiry"] = expiry.isoformat() + "Z"
creds = credentials.Credentials.from_authorized_user_info(info)
assert creds.expiry == expiry
assert creds.expired

def test_from_authorized_user_file(self):
info = AUTH_USER_INFO.copy()

Expand All @@ -381,7 +395,10 @@ def test_from_authorized_user_file(self):

def test_to_json(self):
info = AUTH_USER_INFO.copy()
expiry = datetime.datetime(2020, 8, 14, 15, 54, 1)
info["expiry"] = expiry.isoformat() + "Z"
creds = credentials.Credentials.from_authorized_user_info(info)
assert creds.expiry == expiry

# Test with no `strip` arg
json_output = creds.to_json()
Expand All @@ -392,6 +409,7 @@ def test_to_json(self):
assert json_asdict.get("client_id") == creds.client_id
assert json_asdict.get("scopes") == creds.scopes
assert json_asdict.get("client_secret") == creds.client_secret
assert json_asdict.get("expiry") == info["expiry"]

# Test with a `strip` arg
json_output = creds.to_json(strip=["client_secret"])
Expand All @@ -403,6 +421,12 @@ def test_to_json(self):
assert json_asdict.get("scopes") == creds.scopes
assert json_asdict.get("client_secret") is None

# Test with no expiry
creds.expiry = None
json_output = creds.to_json()
json_asdict = json.loads(json_output)
assert json_asdict.get("expiry") is None

def test_pickle_and_unpickle(self):
creds = self.make_credentials()
unpickled = pickle.loads(pickle.dumps(creds))
Expand Down