diff --git a/google/oauth2/credentials.py b/google/oauth2/credentials.py index dcfa5f912..158249ed5 100644 --- a/google/oauth2/credentials.py +++ b/google/oauth2/credentials.py @@ -74,6 +74,7 @@ def __init__( quota_project_id=None, expiry=None, rapt_token=None, + refresh_handler=None, ): """ Args: @@ -103,6 +104,13 @@ def __init__( This project may be different from the project used to create the credentials. rapt_token (Optional[str]): The reauth Proof Token. + refresh_handler (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]): + A callable which takes in the HTTP request callable and the list of + OAuth scopes and when called returns an access token string for the + requested scopes and its expiry datetime. This is useful when no + refresh tokens are provided and tokens are obtained by calling + some external process on demand. It is particularly useful for + retrieving downscoped tokens from a token broker. """ super(Credentials, self).__init__() self.token = token @@ -116,13 +124,20 @@ def __init__( self._client_secret = client_secret self._quota_project_id = quota_project_id self._rapt_token = rapt_token + self.refresh_handler = refresh_handler def __getstate__(self): """A __getstate__ method must exist for the __setstate__ to be called This is identical to the default implementation. See https://docs.python.org/3.7/library/pickle.html#object.__setstate__ """ - return self.__dict__ + state_dict = self.__dict__.copy() + # Remove _refresh_handler function as there are limitations pickling and + # unpickling certain callables (lambda, functools.partial instances) + # because they need to be importable. + # Instead, the refresh_handler setter should be used to repopulate this. + del state_dict["_refresh_handler"] + return state_dict def __setstate__(self, d): """Credentials pickled with older versions of the class do not have @@ -138,6 +153,8 @@ def __setstate__(self, d): self._client_secret = d.get("_client_secret") self._quota_project_id = d.get("_quota_project_id") self._rapt_token = d.get("_rapt_token") + # The refresh_handler setter should be used to repopulate this. + self._refresh_handler = None @property def refresh_token(self): @@ -187,6 +204,31 @@ def rapt_token(self): """Optional[str]: The reauth Proof Token.""" return self._rapt_token + @property + def refresh_handler(self): + """Returns the refresh handler if available. + + Returns: + Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]: + The current refresh handler. + """ + return self._refresh_handler + + @refresh_handler.setter + def refresh_handler(self, value): + """Updates the current refresh handler. + + Args: + value (Optional[Callable[[google.auth.transport.Request, Sequence[str]], [str, datetime]]]): + The updated value of the refresh handler. + + Raises: + TypeError: If the value is not a callable or None. + """ + if not callable(value) and value is not None: + raise TypeError("The provided refresh_handler is not a callable or None.") + self._refresh_handler = value + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): @@ -205,6 +247,31 @@ def with_quota_project(self, quota_project_id): @_helpers.copy_docstring(credentials.Credentials) def refresh(self, request): + scopes = self._scopes if self._scopes is not None else self._default_scopes + # Use refresh handler if available and no refresh token is + # available. This is useful in general when tokens are obtained by calling + # some external process on demand. It is particularly useful for retrieving + # downscoped tokens from a token broker. + if self._refresh_token is None and self.refresh_handler: + token, expiry = self.refresh_handler(request, scopes=scopes) + # Validate returned data. + if not isinstance(token, str): + raise exceptions.RefreshError( + "The refresh_handler returned token is not a string." + ) + if not isinstance(expiry, datetime): + raise exceptions.RefreshError( + "The refresh_handler returned expiry is not a datetime object." + ) + if _helpers.utcnow() >= expiry - _helpers.CLOCK_SKEW: + raise exceptions.RefreshError( + "The credentials returned by the refresh_handler are " + "already expired." + ) + self.token = token + self.expiry = expiry + return + if ( self._refresh_token is None or self._token_uri is None @@ -217,8 +284,6 @@ def refresh(self, request): "token_uri, client_id, and client_secret." ) - scopes = self._scopes if self._scopes is not None else self._default_scopes - ( access_token, refresh_token, diff --git a/tests/oauth2/test_credentials.py b/tests/oauth2/test_credentials.py index 4a387a58e..4a7f66e7f 100644 --- a/tests/oauth2/test_credentials.py +++ b/tests/oauth2/test_credentials.py @@ -66,6 +66,50 @@ def test_default_state(self): assert credentials.client_id == self.CLIENT_ID assert credentials.client_secret == self.CLIENT_SECRET assert credentials.rapt_token == self.RAPT_TOKEN + assert credentials.refresh_handler is None + + def test_refresh_handler_setter_and_getter(self): + scopes = ["email", "profile"] + original_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_1", None)) + updated_refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN_2", None)) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=None, + refresh_handler=original_refresh_handler, + ) + + assert creds.refresh_handler is original_refresh_handler + + creds.refresh_handler = updated_refresh_handler + + assert creds.refresh_handler is updated_refresh_handler + + creds.refresh_handler = None + + assert creds.refresh_handler is None + + def test_invalid_refresh_handler(self): + scopes = ["email", "profile"] + with pytest.raises(TypeError) as excinfo: + credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=None, + refresh_handler=object(), + ) + + assert excinfo.match("The provided refresh_handler is not a callable or None.") @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( @@ -126,6 +170,221 @@ def test_refresh_no_refresh_token(self): request.assert_not_called() + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + def test_refresh_with_refresh_token_and_refresh_handler( + self, unused_utcnow, refresh_grant + ): + token = "token" + new_rapt_token = "new_rapt_token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = {"id_token": mock.sentinel.id_token} + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + # rapt_token + new_rapt_token, + ) + + refresh_handler = mock.Mock() + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + rapt_token=self.RAPT_TOKEN, + refresh_handler=refresh_handler, + ) + + # Refresh credentials + creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + None, + self.RAPT_TOKEN, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.rapt_token == new_rapt_token + + # Check that the credentials are valid (have a token and are not + # expired) + assert creds.valid + + # Assert refresh handler not called as the refresh token has + # higher priority. + refresh_handler.assert_not_called() + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_success_scopes(self, unused_utcnow): + expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800) + refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + creds.refresh(request) + + assert creds.token == "ACCESS_TOKEN" + assert creds.expiry == expected_expiry + assert creds.valid + assert not creds.expired + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_success_default_scopes(self, unused_utcnow): + expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800) + original_refresh_handler = mock.Mock( + return_value=("UNUSED_TOKEN", expected_expiry) + ) + refresh_handler = mock.Mock(return_value=("ACCESS_TOKEN", expected_expiry)) + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=None, + default_scopes=default_scopes, + refresh_handler=original_refresh_handler, + ) + + # Test newly set refresh_handler is used instead of the original one. + creds.refresh_handler = refresh_handler + creds.refresh(request) + + assert creds.token == "ACCESS_TOKEN" + assert creds.expiry == expected_expiry + assert creds.valid + assert not creds.expired + # default_scopes should be used since no developer provided scopes + # are provided. + refresh_handler.assert_called_with(request, scopes=default_scopes) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_invalid_token(self, unused_utcnow): + expected_expiry = datetime.datetime.min + datetime.timedelta(seconds=2800) + # Simulate refresh handler does not return a valid token. + refresh_handler = mock.Mock(return_value=(None, expected_expiry)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + with pytest.raises( + exceptions.RefreshError, match="returned token is not a string" + ): + creds.refresh(request) + + assert creds.token is None + assert creds.expiry is None + assert not creds.valid + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + + def test_refresh_with_refresh_handler_invalid_expiry(self): + # Simulate refresh handler returns expiration time in an invalid unit. + refresh_handler = mock.Mock(return_value=("TOKEN", 2800)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + with pytest.raises( + exceptions.RefreshError, match="returned expiry is not a datetime object" + ): + creds.refresh(request) + + assert creds.token is None + assert creds.expiry is None + assert not creds.valid + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_with_refresh_handler_expired_token(self, unused_utcnow): + expected_expiry = datetime.datetime.min + _helpers.CLOCK_SKEW + # Simulate refresh handler returns an expired token. + refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry)) + scopes = ["email", "profile"] + default_scopes = ["https://www.googleapis.com/auth/cloud-platform"] + request = mock.create_autospec(transport.Request) + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + scopes=scopes, + default_scopes=default_scopes, + refresh_handler=refresh_handler, + ) + + with pytest.raises(exceptions.RefreshError, match="already expired"): + creds.refresh(request) + + assert creds.token is None + assert creds.expiry is None + assert not creds.valid + # Confirm refresh handler called with the expected arguments. + refresh_handler.assert_called_with(request, scopes=scopes) + @mock.patch("google.oauth2.reauth.refresh_grant", autospec=True) @mock.patch( "google.auth._helpers.utcnow", @@ -527,6 +786,32 @@ def test_pickle_and_unpickle(self): for attr in list(creds.__dict__): assert getattr(creds, attr) == getattr(unpickled, attr) + def test_pickle_and_unpickle_with_refresh_handler(self): + expected_expiry = _helpers.utcnow() + datetime.timedelta(seconds=2800) + refresh_handler = mock.Mock(return_value=("TOKEN", expected_expiry)) + + creds = credentials.Credentials( + token=None, + refresh_token=None, + token_uri=None, + client_id=None, + client_secret=None, + rapt_token=None, + refresh_handler=refresh_handler, + ) + unpickled = pickle.loads(pickle.dumps(creds)) + + # make sure attributes aren't lost during pickling + assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort() + + for attr in list(creds.__dict__): + # For the _refresh_handler property, the unpickled creds should be + # set to None. + if attr == "_refresh_handler": + assert getattr(unpickled, attr) is None + else: + assert getattr(creds, attr) == getattr(unpickled, attr) + def test_pickle_with_missing_attribute(self): creds = self.make_credentials()