diff --git a/docs/reference/google.auth.credentials_async.rst b/docs/reference/google.auth.credentials_async.rst new file mode 100644 index 000000000..683139aa5 --- /dev/null +++ b/docs/reference/google.auth.credentials_async.rst @@ -0,0 +1,7 @@ +google.auth.credentials\_async module +===================================== + +.. automodule:: google.auth._credentials_async + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.jwt_async.rst b/docs/reference/google.auth.jwt_async.rst new file mode 100644 index 000000000..4e56a6ea3 --- /dev/null +++ b/docs/reference/google.auth.jwt_async.rst @@ -0,0 +1,7 @@ +google.auth.jwt\_async module +============================= + +.. automodule:: google.auth.jwt_async + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.rst b/docs/reference/google.auth.rst index cfcf70357..3acf7dfb8 100644 --- a/docs/reference/google.auth.rst +++ b/docs/reference/google.auth.rst @@ -24,8 +24,10 @@ Submodules google.auth.app_engine google.auth.credentials + google.auth._credentials_async google.auth.environment_vars google.auth.exceptions google.auth.iam google.auth.impersonated_credentials google.auth.jwt + google.auth.jwt_async diff --git a/docs/reference/google.auth.transport.aiohttp_requests.rst b/docs/reference/google.auth.transport.aiohttp_requests.rst new file mode 100644 index 000000000..44fc4e577 --- /dev/null +++ b/docs/reference/google.auth.transport.aiohttp_requests.rst @@ -0,0 +1,7 @@ +google.auth.transport.aiohttp\_requests module +============================================== + +.. automodule:: google.auth.transport._aiohttp_requests + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.auth.transport.rst b/docs/reference/google.auth.transport.rst index 89218632b..f1d198855 100644 --- a/docs/reference/google.auth.transport.rst +++ b/docs/reference/google.auth.transport.rst @@ -12,6 +12,7 @@ Submodules .. toctree:: :maxdepth: 4 + google.auth.transport._aiohttp_requests google.auth.transport.grpc google.auth.transport.mtls google.auth.transport.requests diff --git a/docs/reference/google.oauth2.credentials_async.rst b/docs/reference/google.oauth2.credentials_async.rst new file mode 100644 index 000000000..d0df1e8a3 --- /dev/null +++ b/docs/reference/google.oauth2.credentials_async.rst @@ -0,0 +1,7 @@ +google.oauth2.credentials\_async module +======================================= + +.. automodule:: google.oauth2._credentials_async + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/reference/google.oauth2.rst b/docs/reference/google.oauth2.rst index 1ac9c7320..6f3ba50c2 100644 --- a/docs/reference/google.oauth2.rst +++ b/docs/reference/google.oauth2.rst @@ -13,5 +13,7 @@ Submodules :maxdepth: 4 google.oauth2.credentials + google.oauth2._credentials_async google.oauth2.id_token google.oauth2.service_account + google.oauth2._service_account_async diff --git a/docs/reference/google.oauth2.service_account_async.rst b/docs/reference/google.oauth2.service_account_async.rst new file mode 100644 index 000000000..8aba0d851 --- /dev/null +++ b/docs/reference/google.oauth2.service_account_async.rst @@ -0,0 +1,7 @@ +google.oauth2.service\_account\_async module +============================================ + +.. automodule:: google.oauth2._service_account_async + :members: + :inherited-members: + :show-inheritance: diff --git a/google/auth/__init__.py b/google/auth/__init__.py index 5ca20a362..22d61c66f 100644 --- a/google/auth/__init__.py +++ b/google/auth/__init__.py @@ -18,9 +18,7 @@ from google.auth._default import default, load_credentials_from_file - __all__ = ["default", "load_credentials_from_file"] - # Set default logging handler to avoid "No handler found" warnings. logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/google/auth/_credentials_async.py b/google/auth/_credentials_async.py new file mode 100644 index 000000000..d4d4e2c0e --- /dev/null +++ b/google/auth/_credentials_async.py @@ -0,0 +1,176 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Interfaces for credentials.""" + +import abc +import inspect + +import six + +from google.auth import credentials + + +@six.add_metaclass(abc.ABCMeta) +class Credentials(credentials.Credentials): + """Async inherited credentials class from google.auth.credentials. + The added functionality is the before_request call which requires + async/await syntax. + All credentials have a :attr:`token` that is used for authentication and + may also optionally set an :attr:`expiry` to indicate when the token will + no longer be valid. + + Most credentials will be :attr:`invalid` until :meth:`refresh` is called. + Credentials can do this automatically before the first HTTP request in + :meth:`before_request`. + + Although the token and expiration will change as the credentials are + :meth:`refreshed ` and used, credentials should be considered + immutable. Various credentials will accept configuration such as private + keys, scopes, and other options. These options are not changeable after + construction. Some classes will provide mechanisms to copy the credentials + with modifications such as :meth:`ScopedCredentials.with_scopes`. + """ + + async def before_request(self, request, method, url, headers): + """Performs credential-specific before request logic. + + Refreshes the credentials if necessary, then calls :meth:`apply` to + apply the token to the authentication header. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. + method (str): The request's HTTP method or the RPC method being + invoked. + url (str): The request's URI or the RPC service's URI. + headers (Mapping): The request's headers. + """ + # pylint: disable=unused-argument + # (Subclasses may use these arguments to ascertain information about + # the http request.) + + if not self.valid: + if inspect.iscoroutinefunction(self.refresh): + await self.refresh(request) + else: + self.refresh(request) + self.apply(headers) + + +class CredentialsWithQuotaProject(credentials.CredentialsWithQuotaProject): + """Abstract base for credentials supporting ``with_quota_project`` factory""" + + +class AnonymousCredentials(credentials.AnonymousCredentials, Credentials): + """Credentials that do not provide any authentication information. + + These are useful in the case of services that support anonymous access or + local service emulators that do not use credentials. This class inherits + from the sync anonymous credentials file, but is kept if async credentials + is initialized and we would like anonymous credentials. + """ + + +@six.add_metaclass(abc.ABCMeta) +class ReadOnlyScoped(credentials.ReadOnlyScoped): + """Interface for credentials whose scopes can be queried. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = _credentials_async.with_scopes(scopes=['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = _credentials_async.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +class Scoped(credentials.Scoped): + """Interface for credentials whose scopes can be replaced while copying. + + OAuth 2.0-based credentials allow limiting access using scopes as described + in `RFC6749 Section 3.3`_. + If a credential class implements this interface then the credentials either + use scopes in their implementation. + + Some credentials require scopes in order to obtain a token. You can check + if scoping is necessary with :attr:`requires_scopes`:: + + if credentials.requires_scopes: + # Scoping is required. + credentials = _credentials_async.create_scoped(['one', 'two']) + + Credentials that require scopes must either be constructed with scopes:: + + credentials = SomeScopedCredentials(scopes=['one', 'two']) + + Or must copy an existing instance using :meth:`with_scopes`:: + + scoped_credentials = credentials.with_scopes(scopes=['one', 'two']) + + Some credentials have scopes but do not allow or require scopes to be set, + these credentials can be used as-is. + + .. _RFC6749 Section 3.3: https://tools.ietf.org/html/rfc6749#section-3.3 + """ + + +def with_scopes_if_required(credentials, scopes): + """Creates a copy of the credentials with scopes if scoping is required. + + This helper function is useful when you do not know (or care to know) the + specific type of credentials you are using (such as when you use + :func:`google.auth.default`). This function will call + :meth:`Scoped.with_scopes` if the credentials are scoped credentials and if + the credentials require scoping. Otherwise, it will return the credentials + as-is. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to + scope if necessary. + scopes (Sequence[str]): The list of scopes to use. + + Returns: + google.auth._credentials_async.Credentials: Either a new set of scoped + credentials, or the passed in credentials instance if no scoping + was required. + """ + if isinstance(credentials, Scoped) and credentials.requires_scopes: + return credentials.with_scopes(scopes) + else: + return credentials + + +@six.add_metaclass(abc.ABCMeta) +class Signing(credentials.Signing): + """Interface for credentials that can cryptographically sign messages.""" diff --git a/google/auth/_default_async.py b/google/auth/_default_async.py new file mode 100644 index 000000000..3347fbfdc --- /dev/null +++ b/google/auth/_default_async.py @@ -0,0 +1,266 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Application default credentials. + +Implements application default credentials and project ID detection. +""" + +import io +import json +import os + +import six + +from google.auth import _default +from google.auth import environment_vars +from google.auth import exceptions + + +def load_credentials_from_file(filename, scopes=None, quota_project_id=None): + """Loads Google credentials from a file. + + The credentials file must be a service account key or stored authorized + user credentials. + + Args: + filename (str): The full path to the credentials file. + scopes (Optional[Sequence[str]]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary + quota_project_id (Optional[str]): The project ID used for + quota and billing. + + Returns: + Tuple[google.auth.credentials.Credentials, Optional[str]]: Loaded + credentials and the project ID. Authorized user credentials do not + have the project ID information. + + Raises: + google.auth.exceptions.DefaultCredentialsError: if the file is in the + wrong format or is missing. + """ + if not os.path.exists(filename): + raise exceptions.DefaultCredentialsError( + "File {} was not found.".format(filename) + ) + + with io.open(filename, "r") as file_obj: + try: + info = json.load(file_obj) + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "File {} is not a valid json file.".format(filename), caught_exc + ) + six.raise_from(new_exc, caught_exc) + + # The type key should indicate that the file is either a service account + # credentials file or an authorized user credentials file. + credential_type = info.get("type") + + if credential_type == _default._AUTHORIZED_USER_TYPE: + from google.oauth2 import _credentials_async as credentials + + try: + credentials = credentials.Credentials.from_authorized_user_info( + info, scopes=scopes + ).with_quota_project(quota_project_id) + except ValueError as caught_exc: + msg = "Failed to load authorized user credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + six.raise_from(new_exc, caught_exc) + if not credentials.quota_project_id: + _default._warn_about_problematic_credentials(credentials) + return credentials, None + + elif credential_type == _default._SERVICE_ACCOUNT_TYPE: + from google.oauth2 import _service_account_async as service_account + + try: + credentials = service_account.Credentials.from_service_account_info( + info, scopes=scopes + ).with_quota_project(quota_project_id) + except ValueError as caught_exc: + msg = "Failed to load service account credentials from {}".format(filename) + new_exc = exceptions.DefaultCredentialsError(msg, caught_exc) + six.raise_from(new_exc, caught_exc) + return credentials, info.get("project_id") + + else: + raise exceptions.DefaultCredentialsError( + "The file {file} does not have a valid type. " + "Type is {type}, expected one of {valid_types}.".format( + file=filename, type=credential_type, valid_types=_default._VALID_TYPES + ) + ) + + +def _get_gcloud_sdk_credentials(): + """Gets the credentials and project ID from the Cloud SDK.""" + from google.auth import _cloud_sdk + + # Check if application default credentials exist. + credentials_filename = _cloud_sdk.get_application_default_credentials_path() + + if not os.path.isfile(credentials_filename): + return None, None + + credentials, project_id = load_credentials_from_file(credentials_filename) + + if not project_id: + project_id = _cloud_sdk.get_project_id() + + return credentials, project_id + + +def _get_explicit_environ_credentials(): + """Gets credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + variable.""" + explicit_file = os.environ.get(environment_vars.CREDENTIALS) + + if explicit_file is not None: + credentials, project_id = load_credentials_from_file( + os.environ[environment_vars.CREDENTIALS] + ) + + return credentials, project_id + + else: + return None, None + + +def _get_gae_credentials(): + """Gets Google App Engine App Identity credentials and project ID.""" + # While this library is normally bundled with app_engine, there are + # some cases where it's not available, so we tolerate ImportError. + + return _default._get_gae_credentials() + + +def _get_gce_credentials(request=None): + """Gets credentials and project ID from the GCE Metadata Service.""" + # Ping requires a transport, but we want application default credentials + # to require no arguments. So, we'll use the _http_client transport which + # uses http.client. This is only acceptable because the metadata server + # doesn't do SSL and never requires proxies. + + # While this library is normally bundled with compute_engine, there are + # some cases where it's not available, so we tolerate ImportError. + + return _default._get_gce_credentials(request) + + +def default_async(scopes=None, request=None, quota_project_id=None): + """Gets the default credentials for the current environment. + + `Application Default Credentials`_ provides an easy way to obtain + credentials to call Google APIs for server-to-server or local applications. + This function acquires credentials from the environment in the following + order: + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON private key file, then it is + loaded and returned. The project ID returned is the project ID defined + in the service account file if available (some older files do not + contain project ID information). + 2. If the `Google Cloud SDK`_ is installed and has application default + credentials set they are loaded and returned. + + To enable application default credentials with the Cloud SDK run:: + + gcloud auth application-default login + + If the Cloud SDK has an active project, the project ID is returned. The + active project can be set using:: + + gcloud config set project + + 3. If the application is running in the `App Engine standard environment`_ + then the credentials and project ID from the `App Identity Service`_ + are used. + 4. If the application is running in `Compute Engine`_ or the + `App Engine flexible environment`_ then the credentials and project ID + are obtained from the `Metadata Service`_. + 5. If no credentials are found, + :class:`~google.auth.exceptions.DefaultCredentialsError` will be raised. + + .. _Application Default Credentials: https://developers.google.com\ + /identity/protocols/application-default-credentials + .. _Google Cloud SDK: https://cloud.google.com/sdk + .. _App Engine standard environment: https://cloud.google.com/appengine + .. _App Identity Service: https://cloud.google.com/appengine/docs/python\ + /appidentity/ + .. _Compute Engine: https://cloud.google.com/compute + .. _App Engine flexible environment: https://cloud.google.com\ + /appengine/flexible + .. _Metadata Service: https://cloud.google.com/compute/docs\ + /storing-retrieving-metadata + + Example:: + + import google.auth + + credentials, project_id = google.auth.default() + + Args: + scopes (Sequence[str]): The list of scopes for the credentials. If + specified, the credentials will automatically be scoped if + necessary. + request (google.auth.transport.Request): An object used to make + HTTP requests. This is used to detect whether the application + is running on Compute Engine. If not specified, then it will + use the standard library http client to make requests. + quota_project_id (Optional[str]): The project ID used for + quota and billing. + Returns: + Tuple[~google.auth.credentials.Credentials, Optional[str]]: + the current environment's credentials and project ID. Project ID + may be None, which indicates that the Project ID could not be + ascertained from the environment. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If no credentials were found, or if the credentials found were + invalid. + """ + from google.auth._credentials_async import with_scopes_if_required + + explicit_project_id = os.environ.get( + environment_vars.PROJECT, os.environ.get(environment_vars.LEGACY_PROJECT) + ) + + checkers = ( + _get_explicit_environ_credentials, + _get_gcloud_sdk_credentials, + _get_gae_credentials, + lambda: _get_gce_credentials(request), + ) + + for checker in checkers: + credentials, project_id = checker() + if credentials is not None: + credentials = with_scopes_if_required( + credentials, scopes + ).with_quota_project(quota_project_id) + effective_project_id = explicit_project_id or project_id + if not effective_project_id: + _default._LOGGER.warning( + "No project ID could be determined. Consider running " + "`gcloud config set project` or setting the %s " + "environment variable", + environment_vars.PROJECT, + ) + return credentials, effective_project_id + + raise exceptions.DefaultCredentialsError(_default._HELP_MESSAGE) diff --git a/google/auth/_jwt_async.py b/google/auth/_jwt_async.py new file mode 100644 index 000000000..49e3026e5 --- /dev/null +++ b/google/auth/_jwt_async.py @@ -0,0 +1,168 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JSON Web Tokens + +Provides support for creating (encoding) and verifying (decoding) JWTs, +especially JWTs generated and consumed by Google infrastructure. + +See `rfc7519`_ for more details on JWTs. + +To encode a JWT use :func:`encode`:: + + from google.auth import crypt + from google.auth import jwt_async + + signer = crypt.Signer(private_key) + payload = {'some': 'payload'} + encoded = jwt_async.encode(signer, payload) + +To decode a JWT and verify claims use :func:`decode`:: + + claims = jwt_async.decode(encoded, certs=public_certs) + +You can also skip verification:: + + claims = jwt_async.decode(encoded, verify=False) + +.. _rfc7519: https://tools.ietf.org/html/rfc7519 + + +NOTE: This async support is experimental and marked internal. This surface may +change in minor releases. +""" + +import google.auth +from google.auth import jwt + + +def encode(signer, payload, header=None, key_id=None): + """Make a signed JWT. + + Args: + signer (google.auth.crypt.Signer): The signer used to sign the JWT. + payload (Mapping[str, str]): The JWT payload. + header (Mapping[str, str]): Additional JWT header payload. + key_id (str): The key id to add to the JWT header. If the + signer has a key id it will be used as the default. If this is + specified it will override the signer's key id. + + Returns: + bytes: The encoded JWT. + """ + return jwt.encode(signer, payload, header, key_id) + + +def decode(token, certs=None, verify=True, audience=None): + """Decode and verify a JWT. + + Args: + token (str): The encoded JWT. + certs (Union[str, bytes, Mapping[str, Union[str, bytes]]]): The + certificate used to validate the JWT signature. If bytes or string, + it must the the public key certificate in PEM format. If a mapping, + it must be a mapping of key IDs to public key certificates in PEM + format. The mapping must contain the same key ID that's specified + in the token's header. + verify (bool): Whether to perform signature and claim validation. + Verification is done by default. + audience (str): The audience claim, 'aud', that this JWT should + contain. If None then the JWT's 'aud' parameter is not verified. + + Returns: + Mapping[str, str]: The deserialized JSON payload in the JWT. + + Raises: + ValueError: if any verification checks failed. + """ + + return jwt.decode(token, certs, verify, audience) + + +class Credentials( + jwt.Credentials, + google.auth._credentials_async.Signing, + google.auth._credentials_async.Credentials, +): + """Credentials that use a JWT as the bearer token. + + These credentials require an "audience" claim. This claim identifies the + intended recipient of the bearer token. + + The constructor arguments determine the claims for the JWT that is + sent with requests. Usually, you'll construct these credentials with + one of the helper constructors as shown in the next section. + + To create JWT credentials using a Google service account private key + JSON file:: + + audience = 'https://pubsub.googleapis.com/google.pubsub.v1.Publisher' + credentials = jwt_async.Credentials.from_service_account_file( + 'service-account.json', + audience=audience) + + If you already have the service account file loaded and parsed:: + + service_account_info = json.load(open('service_account.json')) + credentials = jwt_async.Credentials.from_service_account_info( + service_account_info, + audience=audience) + + Both helper methods pass on arguments to the constructor, so you can + specify the JWT claims:: + + credentials = jwt_async.Credentials.from_service_account_file( + 'service-account.json', + audience=audience, + additional_claims={'meta': 'data'}) + + You can also construct the credentials directly if you have a + :class:`~google.auth.crypt.Signer` instance:: + + credentials = jwt_async.Credentials( + signer, + issuer='your-issuer', + subject='your-subject', + audience=audience) + + The claims are considered immutable. If you want to modify the claims, + you can easily create another instance using :meth:`with_claims`:: + + new_audience = ( + 'https://pubsub.googleapis.com/google.pubsub.v1.Subscriber') + new_credentials = credentials.with_claims(audience=new_audience) + """ + + +class OnDemandCredentials( + jwt.OnDemandCredentials, + google.auth._credentials_async.Signing, + google.auth._credentials_async.Credentials, +): + """On-demand JWT credentials. + + Like :class:`Credentials`, this class uses a JWT as the bearer token for + authentication. However, this class does not require the audience at + construction time. Instead, it will generate a new token on-demand for + each request using the request URI as the audience. It caches tokens + so that multiple requests to the same URI do not incur the overhead + of generating a new token every time. + + This behavior is especially useful for `gRPC`_ clients. A gRPC service may + have multiple audience and gRPC clients may not know all of the audiences + required for accessing a particular service. With these credentials, + no knowledge of the audiences is required ahead of time. + + .. _grpc: http://www.grpc.io/ + """ diff --git a/google/auth/transport/_aiohttp_requests.py b/google/auth/transport/_aiohttp_requests.py new file mode 100644 index 000000000..aaf4e2c0b --- /dev/null +++ b/google/auth/transport/_aiohttp_requests.py @@ -0,0 +1,384 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transport adapter for Async HTTP (aiohttp). + +NOTE: This async support is experimental and marked internal. This surface may +change in minor releases. +""" + +from __future__ import absolute_import + +import asyncio +import functools + +import aiohttp +import six +import urllib3 + +from google.auth import exceptions +from google.auth import transport +from google.auth.transport import requests + +# Timeout can be re-defined depending on async requirement. Currently made 60s more than +# sync timeout. +_DEFAULT_TIMEOUT = 180 # in seconds + + +class _CombinedResponse(transport.Response): + """ + In order to more closely resemble the `requests` interface, where a raw + and deflated content could be accessed at once, this class lazily reads the + stream in `transport.Response` so both return forms can be used. + + The gzip and deflate transfer-encodings are automatically decoded for you + because the default parameter for autodecompress into the ClientSession is set + to False, and therefore we add this class to act as a wrapper for a user to be + able to access both the raw and decoded response bodies - mirroring the sync + implementation. + """ + + def __init__(self, response): + self._response = response + self._raw_content = None + + def _is_compressed(self): + headers = self._response.headers + return "Content-Encoding" in headers and ( + headers["Content-Encoding"] == "gzip" + or headers["Content-Encoding"] == "deflate" + ) + + @property + def status(self): + return self._response.status + + @property + def headers(self): + return self._response.headers + + @property + def data(self): + return self._response.content + + async def raw_content(self): + if self._raw_content is None: + self._raw_content = await self._response.content.read() + return self._raw_content + + async def content(self): + # Load raw_content if necessary + await self.raw_content() + if self._is_compressed(): + decoder = urllib3.response.MultiDecoder( + self._response.headers["Content-Encoding"] + ) + decompressed = decoder.decompress(self._raw_content) + return decompressed + + return self._raw_content + + +class _Response(transport.Response): + """ + Requests transport response adapter. + + Args: + response (requests.Response): The raw Requests response. + """ + + def __init__(self, response): + self._response = response + + @property + def status(self): + return self._response.status + + @property + def headers(self): + return self._response.headers + + @property + def data(self): + return self._response.content + + +class Request(transport.Request): + """Requests request adapter. + + This class is used internally for making requests using asyncio transports + in a consistent way. If you use :class:`AuthorizedSession` you do not need + to construct or use this class directly. + + This class can be useful if you want to manually refresh a + :class:`~google.auth.credentials.Credentials` instance:: + + import google.auth.transport.aiohttp_requests + + request = google.auth.transport.aiohttp_requests.Request() + + credentials.refresh(request) + + Args: + session (aiohttp.ClientSession): An instance :class: aiohttp.ClientSession used + to make HTTP requests. If not specified, a session will be created. + + .. automethod:: __call__ + """ + + def __init__(self, session=None): + self.session = None + + async def __call__( + self, + url, + method="GET", + body=None, + headers=None, + timeout=_DEFAULT_TIMEOUT, + **kwargs, + ): + """ + Make an HTTP request using aiohttp. + + Args: + url (str): The URL to be requested. + method (str): The HTTP method to use for the request. Defaults + to 'GET'. + body (bytes): The payload / body in HTTP request. + headers (Mapping[str, str]): Request headers. + timeout (Optional[int]): The number of seconds to wait for a + response from the server. If not specified or if None, the + requests default timeout will be used. + kwargs: Additional arguments passed through to the underlying + requests :meth:`~requests.Session.request` method. + + Returns: + google.auth.transport.Response: The HTTP response. + + Raises: + google.auth.exceptions.TransportError: If any exception occurred. + """ + + try: + if self.session is None: # pragma: NO COVER + self.session = aiohttp.ClientSession( + auto_decompress=False + ) # pragma: NO COVER + requests._LOGGER.debug("Making request: %s %s", method, url) + response = await self.session.request( + method, url, data=body, headers=headers, timeout=timeout, **kwargs + ) + return _CombinedResponse(response) + + except aiohttp.ClientError as caught_exc: + new_exc = exceptions.TransportError(caught_exc) + six.raise_from(new_exc, caught_exc) + + except asyncio.TimeoutError as caught_exc: + new_exc = exceptions.TransportError(caught_exc) + six.raise_from(new_exc, caught_exc) + + +class AuthorizedSession(aiohttp.ClientSession): + """This is an async implementation of the Authorized Session class. We utilize an + aiohttp transport instance, and the interface mirrors the google.auth.transport.requests + Authorized Session class, except for the change in the transport used in the async use case. + + A Requests Session class with credentials. + + This class is used to perform requests to API endpoints that require + authorization:: + + from google.auth.transport import aiohttp_requests + + async with aiohttp_requests.AuthorizedSession(credentials) as authed_session: + response = await authed_session.request( + 'GET', 'https://www.googleapis.com/storage/v1/b') + + The underlying :meth:`request` implementation handles adding the + credentials' headers to the request and refreshing credentials as needed. + + Args: + credentials (google.auth._credentials_async.Credentials): The credentials to + add to the request. + refresh_status_codes (Sequence[int]): Which HTTP status codes indicate + that credentials should be refreshed and the request should be + retried. + max_refresh_attempts (int): The maximum number of times to attempt to + refresh the credentials and retry the request. + refresh_timeout (Optional[int]): The timeout value in seconds for + credential refresh HTTP requests. + auth_request (google.auth.transport.aiohttp_requests.Request): + (Optional) An instance of + :class:`~google.auth.transport.aiohttp_requests.Request` used when + refreshing credentials. If not passed, + an instance of :class:`~google.auth.transport.aiohttp_requests.Request` + is created. + """ + + def __init__( + self, + credentials, + refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, + max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, + refresh_timeout=None, + auth_request=None, + auto_decompress=False, + ): + super(AuthorizedSession, self).__init__() + self.credentials = credentials + self._refresh_status_codes = refresh_status_codes + self._max_refresh_attempts = max_refresh_attempts + self._refresh_timeout = refresh_timeout + self._is_mtls = False + self._auth_request = auth_request + self._auth_request_session = None + self._loop = asyncio.get_event_loop() + self._refresh_lock = asyncio.Lock() + self._auto_decompress = auto_decompress + + async def request( + self, + method, + url, + data=None, + headers=None, + max_allowed_time=None, + timeout=_DEFAULT_TIMEOUT, + auto_decompress=False, + **kwargs, + ): + + """Implementation of Authorized Session aiohttp request. + + Args: + method: The http request method used (e.g. GET, PUT, DELETE) + + url: The url at which the http request is sent. + + data, headers: These fields parallel the associated data and headers + fields of a regular http request. Using the aiohttp client session to + send the http request allows us to use this parallel corresponding structure + in our Authorized Session class. + + timeout (Optional[Union[float, aiohttp.ClientTimeout]]): + The amount of time in seconds to wait for the server response + with each individual request. + + Can also be passed as an `aiohttp.ClientTimeout` object. + + max_allowed_time (Optional[float]): + If the method runs longer than this, a ``Timeout`` exception is + automatically raised. Unlike the ``timeout` parameter, this + value applies to the total method execution time, even if + multiple requests are made under the hood. + + Mind that it is not guaranteed that the timeout error is raised + at ``max_allowed_time`. It might take longer, for example, if + an underlying request takes a lot of time, but the request + itself does not timeout, e.g. if a large file is being + transmitted. The timout error will be raised after such + request completes. + """ + # Headers come in as bytes which isn't expected behavior, the resumable + # media libraries in some cases expect a str type for the header values, + # but sometimes the operations return these in bytes types. + if headers: + for key in headers.keys(): + if type(headers[key]) is bytes: + headers[key] = headers[key].decode("utf-8") + + async with aiohttp.ClientSession( + auto_decompress=self._auto_decompress + ) as self._auth_request_session: + auth_request = Request(self._auth_request_session) + self._auth_request = auth_request + + # Use a kwarg for this instead of an attribute to maintain + # thread-safety. + _credential_refresh_attempt = kwargs.pop("_credential_refresh_attempt", 0) + # Make a copy of the headers. They will be modified by the credentials + # and we want to pass the original headers if we recurse. + request_headers = headers.copy() if headers is not None else {} + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + remaining_time = max_allowed_time + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + await self.credentials.before_request( + auth_request, method, url, request_headers + ) + + with requests.TimeoutGuard(remaining_time, asyncio.TimeoutError) as guard: + response = await super(AuthorizedSession, self).request( + method, + url, + data=data, + headers=request_headers, + timeout=timeout, + **kwargs, + ) + + remaining_time = guard.remaining_timeout + + if ( + response.status in self._refresh_status_codes + and _credential_refresh_attempt < self._max_refresh_attempts + ): + + requests._LOGGER.info( + "Refreshing credentials due to a %s response. Attempt %s/%s.", + response.status, + _credential_refresh_attempt + 1, + self._max_refresh_attempts, + ) + + # Do not apply the timeout unconditionally in order to not override the + # _auth_request's default timeout. + auth_request = ( + self._auth_request + if timeout is None + else functools.partial(self._auth_request, timeout=timeout) + ) + + with requests.TimeoutGuard( + remaining_time, asyncio.TimeoutError + ) as guard: + async with self._refresh_lock: + await self._loop.run_in_executor( + None, self.credentials.refresh, auth_request + ) + + remaining_time = guard.remaining_timeout + + return await self.request( + method, + url, + data=data, + headers=headers, + max_allowed_time=remaining_time, + timeout=timeout, + _credential_refresh_attempt=_credential_refresh_attempt + 1, + **kwargs, + ) + + return response diff --git a/google/auth/transport/mtls.py b/google/auth/transport/mtls.py index 5b742306b..b40bfbedf 100644 --- a/google/auth/transport/mtls.py +++ b/google/auth/transport/mtls.py @@ -86,9 +86,12 @@ def default_client_encrypted_cert_source(cert_path, key_path): def callback(): try: - _, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials( - generate_encrypted_key=True - ) + ( + _, + cert_bytes, + key_bytes, + passphrase_bytes, + ) = _mtls_helper.get_client_ssl_credentials(generate_encrypted_key=True) with open(cert_path, "wb") as cert_file: cert_file.write(cert_bytes) with open(key_path, "wb") as key_file: diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py new file mode 100644 index 000000000..4817ea40e --- /dev/null +++ b/google/oauth2/_client_async.py @@ -0,0 +1,264 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 async client. + +This is a client for interacting with an OAuth 2.0 authorization server's +token endpoint. + +For more information about the token endpoint, see +`Section 3.1 of rfc6749`_ + +.. _Section 3.1 of rfc6749: https://tools.ietf.org/html/rfc6749#section-3.2 +""" + +import datetime +import json + +import six +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import jwt +from google.oauth2 import _client as client + + +def _handle_error_response(response_body): + """"Translates an error response into an exception. + + Args: + response_body (str): The decoded response data. + + Raises: + google.auth.exceptions.RefreshError + """ + try: + error_data = json.loads(response_body) + error_details = "{}: {}".format( + error_data["error"], error_data.get("error_description") + ) + # If no details could be extracted, use the response data. + except (KeyError, ValueError): + error_details = response_body + + raise exceptions.RefreshError(error_details, response_body) + + +def _parse_expiry(response_data): + """Parses the expiry field from a response into a datetime. + + Args: + response_data (Mapping): The JSON-parsed response data. + + Returns: + Optional[datetime]: The expiration or ``None`` if no expiration was + specified. + """ + expires_in = response_data.get("expires_in", None) + + if expires_in is not None: + return _helpers.utcnow() + datetime.timedelta(seconds=expires_in) + else: + return None + + +async def _token_endpoint_request(request, token_uri, body): + """Makes a request to the OAuth 2.0 authorization server's token endpoint. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + body (Mapping[str, str]): The parameters to send in the request body. + + Returns: + Mapping[str, str]: The JSON-decoded response data. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = urllib.parse.urlencode(body).encode("utf-8") + headers = {"content-type": client._URLENCODED_CONTENT_TYPE} + + retry = 0 + # retry to fetch token for maximum of two times if any internal failure + # occurs. + while True: + + response = await request( + method="POST", url=token_uri, headers=headers, body=body + ) + + # Using data.read() resulted in zlib decompression errors. This may require future investigation. + response_body1 = await response.content() + + response_body = ( + response_body1.decode("utf-8") + if hasattr(response_body1, "decode") + else response_body1 + ) + + response_data = json.loads(response_body) + + if response.status == http_client.OK: + break + else: + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + if ( + any(e == "internal_failure" for e in (error_code, error_desc)) + and retry < 1 + ): + retry += 1 + continue + _handle_error_response(response_body) + + return response_data + + +async def jwt_grant(request, token_uri, assertion): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants. + + For more details, see `rfc7523 section 4`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + assertion (str): The OAuth 2.0 assertion. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: The access token, + expiration, and additional data returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc7523 section 4: https://tools.ietf.org/html/rfc7523#section-4 + """ + body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} + + response_data = await _token_endpoint_request(request, token_uri, body) + + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No access token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + expiry = _parse_expiry(response_data) + + return access_token, expiry, response_data + + +async def id_token_jwt_grant(request, token_uri, assertion): + """Implements the JWT Profile for OAuth 2.0 Authorization Grants, but + requests an OpenID Connect ID Token instead of an access token. + + This is a variant on the standard JWT Profile that is currently unique + to Google. This was added for the benefit of authenticating to services + that require ID Tokens instead of access tokens or JWT bearer tokens. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorization server's token endpoint + URI. + assertion (str): JWT token signed by a service account. The token's + payload must include a ``target_audience`` claim. + + Returns: + Tuple[str, Optional[datetime], Mapping[str, str]]: + The (encoded) Open ID Connect ID Token, expiration, and additional + data returned by the endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + """ + body = {"assertion": assertion, "grant_type": client._JWT_GRANT_TYPE} + + response_data = await _token_endpoint_request(request, token_uri, body) + + try: + id_token = response_data["id_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No ID token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + payload = jwt.decode(id_token, verify=False) + expiry = datetime.datetime.utcfromtimestamp(payload["exp"]) + + return id_token, expiry, response_data + + +async def refresh_grant( + request, token_uri, refresh_token, client_id, client_secret, scopes=None +): + """Implements the OAuth 2.0 refresh token grant. + + For more details, see `rfc678 section 6`_. + + Args: + request (google.auth.transport.Request): A callable used to make + HTTP requests. + token_uri (str): The OAuth 2.0 authorizations server's token endpoint + URI. + refresh_token (str): The refresh token to use to get a new access + token. + client_id (str): The OAuth 2.0 application's client ID. + client_secret (str): The Oauth 2.0 appliaction's client secret. + scopes (Optional(Sequence[str])): Scopes to request. If present, all + scopes must be authorized for the refresh token. Useful if refresh + token has a wild card scope (e.g. + 'https://www.googleapis.com/auth/any-api'). + + Returns: + Tuple[str, Optional[str], Optional[datetime], Mapping[str, str]]: The + access token, new refresh token, expiration, and additional data + returned by the token endpoint. + + Raises: + google.auth.exceptions.RefreshError: If the token endpoint returned + an error. + + .. _rfc6748 section 6: https://tools.ietf.org/html/rfc6749#section-6 + """ + body = { + "grant_type": client._REFRESH_GRANT_TYPE, + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + } + if scopes: + body["scope"] = " ".join(scopes) + + response_data = await _token_endpoint_request(request, token_uri, body) + + try: + access_token = response_data["access_token"] + except KeyError as caught_exc: + new_exc = exceptions.RefreshError("No access token in response.", response_data) + six.raise_from(new_exc, caught_exc) + + refresh_token = response_data.get("refresh_token", refresh_token) + expiry = _parse_expiry(response_data) + + return access_token, refresh_token, expiry, response_data diff --git a/google/oauth2/_credentials_async.py b/google/oauth2/_credentials_async.py new file mode 100644 index 000000000..eb3e97c08 --- /dev/null +++ b/google/oauth2/_credentials_async.py @@ -0,0 +1,108 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OAuth 2.0 Async Credentials. + +This module provides credentials based on OAuth 2.0 access and refresh tokens. +These credentials usually access resources on behalf of a user (resource +owner). + +Specifically, this is intended to use access tokens acquired using the +`Authorization Code grant`_ and can refresh those tokens using a +optional `refresh token`_. + +Obtaining the initial access and refresh token is outside of the scope of this +module. Consult `rfc6749 section 4.1`_ for complete details on the +Authorization Code grant flow. + +.. _Authorization Code grant: https://tools.ietf.org/html/rfc6749#section-1.3.1 +.. _refresh token: https://tools.ietf.org/html/rfc6749#section-6 +.. _rfc6749 section 4.1: https://tools.ietf.org/html/rfc6749#section-4.1 +""" + +from google.auth import _credentials_async as credentials +from google.auth import _helpers +from google.auth import exceptions +from google.oauth2 import _client_async as _client +from google.oauth2 import credentials as oauth2_credentials + + +class Credentials(oauth2_credentials.Credentials): + """Credentials using OAuth 2.0 access and refresh tokens. + + The credentials are considered immutable. If you want to modify the + quota project, use :meth:`with_quota_project` or :: + + credentials = credentials.with_quota_project('myproject-123) + """ + + @_helpers.copy_docstring(credentials.Credentials) + async def refresh(self, request): + if ( + self._refresh_token is None + or self._token_uri is None + or self._client_id is None + or self._client_secret is None + ): + raise exceptions.RefreshError( + "The credentials do not contain the necessary fields need to " + "refresh the access token. You must specify refresh_token, " + "token_uri, client_id, and client_secret." + ) + + ( + access_token, + refresh_token, + expiry, + grant_response, + ) = await _client.refresh_grant( + request, + self._token_uri, + self._refresh_token, + self._client_id, + self._client_secret, + self._scopes, + ) + + self.token = access_token + self.expiry = expiry + self._refresh_token = refresh_token + self._id_token = grant_response.get("id_token") + + if self._scopes and "scopes" in grant_response: + requested_scopes = frozenset(self._scopes) + granted_scopes = frozenset(grant_response["scopes"].split()) + scopes_requested_but_not_granted = requested_scopes - granted_scopes + if scopes_requested_but_not_granted: + raise exceptions.RefreshError( + "Not all requested scopes were granted by the " + "authorization server, missing scopes {}.".format( + ", ".join(scopes_requested_but_not_granted) + ) + ) + + +class UserAccessTokenCredentials(oauth2_credentials.UserAccessTokenCredentials): + """Access token credentials for user account. + + Obtain the access token for a given user account or the current active + user account with the ``gcloud auth print-access-token`` command. + + Args: + account (Optional[str]): Account to get the access token for. If not + specified, the current active account will be used. + quota_project_id (Optional[str]): The project ID used for quota + and billing. + + """ diff --git a/google/oauth2/_id_token_async.py b/google/oauth2/_id_token_async.py new file mode 100644 index 000000000..f5ef8baff --- /dev/null +++ b/google/oauth2/_id_token_async.py @@ -0,0 +1,267 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Google ID Token helpers. + +Provides support for verifying `OpenID Connect ID Tokens`_, especially ones +generated by Google infrastructure. + +To parse and verify an ID Token issued by Google's OAuth 2.0 authorization +server use :func:`verify_oauth2_token`. To verify an ID Token issued by +Firebase, use :func:`verify_firebase_token`. + +A general purpose ID Token verifier is available as :func:`verify_token`. + +Example:: + + from google.oauth2 import _id_token_async + from google.auth.transport import aiohttp_requests + + request = aiohttp_requests.Request() + + id_info = await _id_token_async.verify_oauth2_token( + token, request, 'my-client-id.example.com') + + if id_info['iss'] != 'https://accounts.google.com': + raise ValueError('Wrong issuer.') + + userid = id_info['sub'] + +By default, this will re-fetch certificates for each verification. Because +Google's public keys are only changed infrequently (on the order of once per +day), you may wish to take advantage of caching to reduce latency and the +potential for network errors. This can be accomplished using an external +library like `CacheControl`_ to create a cache-aware +:class:`google.auth.transport.Request`:: + + import cachecontrol + import google.auth.transport.requests + import requests + + session = requests.session() + cached_session = cachecontrol.CacheControl(session) + request = google.auth.transport.requests.Request(session=cached_session) + +.. _OpenID Connect ID Token: + http://openid.net/specs/openid-connect-core-1_0.html#IDToken +.. _CacheControl: https://cachecontrol.readthedocs.io +""" + +import json +import os + +import six +from six.moves import http_client + +from google.auth import environment_vars +from google.auth import exceptions +from google.auth import jwt +from google.auth.transport import requests +from google.oauth2 import id_token as sync_id_token + + +async def _fetch_certs(request, certs_url): + """Fetches certificates. + + Google-style cerificate endpoints return JSON in the format of + ``{'key id': 'x509 certificate'}``. + + Args: + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + certs_url (str): The certificate endpoint URL. + + Returns: + Mapping[str, str]: A mapping of public key ID to x.509 certificate + data. + """ + response = await request(certs_url, method="GET") + + if response.status != http_client.OK: + raise exceptions.TransportError( + "Could not fetch certificates at {}".format(certs_url) + ) + + data = await response.data.read() + + return json.loads(json.dumps(data)) + + +async def verify_token( + id_token, request, audience=None, certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL +): + """Verifies an ID token and returns the decoded token. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + audience (str): The audience that this token is intended for. If None + then the audience is not verified. + certs_url (str): The URL that specifies the certificates to use to + verify the token. This URL should return JSON in the format of + ``{'key id': 'x509 certificate'}``. + + Returns: + Mapping[str, Any]: The decoded token. + """ + certs = await _fetch_certs(request, certs_url) + + return jwt.decode(id_token, certs=certs, audience=audience) + + +async def verify_oauth2_token(id_token, request, audience=None): + """Verifies an ID Token issued by Google's OAuth 2.0 authorization server. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + audience (str): The audience that this token is intended for. This is + typically your application's OAuth 2.0 client ID. If None then the + audience is not verified. + + Returns: + Mapping[str, Any]: The decoded token. + + Raises: + exceptions.GoogleAuthError: If the issuer is invalid. + """ + idinfo = await verify_token( + id_token, + request, + audience=audience, + certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL, + ) + + if idinfo["iss"] not in sync_id_token._GOOGLE_ISSUERS: + raise exceptions.GoogleAuthError( + "Wrong issuer. 'iss' should be one of the following: {}".format( + sync_id_token._GOOGLE_ISSUERS + ) + ) + + return idinfo + + +async def verify_firebase_token(id_token, request, audience=None): + """Verifies an ID Token issued by Firebase Authentication. + + Args: + id_token (Union[str, bytes]): The encoded token. + request (google.auth.transport.Request): The object used to make + HTTP requests. This must be an aiohttp request. + audience (str): The audience that this token is intended for. This is + typically your Firebase application ID. If None then the audience + is not verified. + + Returns: + Mapping[str, Any]: The decoded token. + """ + return await verify_token( + id_token, + request, + audience=audience, + certs_url=sync_id_token._GOOGLE_APIS_CERTS_URL, + ) + + +async def fetch_id_token(request, audience): + """Fetch the ID Token from the current environment. + + This function acquires ID token from the environment in the following order: + + 1. If the application is running in Compute Engine, App Engine or Cloud Run, + then the ID token are obtained from the metadata server. + 2. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON file, then ID token is + acquired using this service account credentials. + 3. If metadata server doesn't exist and no valid service account credentials + are found, :class:`~google.auth.exceptions.DefaultCredentialsError` will + be raised. + + Example:: + + import google.oauth2._id_token_async + import google.auth.transport.aiohttp_requests + + request = google.auth.transport.aiohttp_requests.Request() + target_audience = "https://pubsub.googleapis.com" + + id_token = await google.oauth2._id_token_async.fetch_id_token(request, target_audience) + + Args: + request (google.auth.transport.aiohttp_requests.Request): A callable used to make + HTTP requests. + audience (str): The audience that this ID token is intended for. + + Returns: + str: The ID token. + + Raises: + ~google.auth.exceptions.DefaultCredentialsError: + If metadata server doesn't exist and no valid service account + credentials are found. + """ + # 1. First try to fetch ID token from metadata server if it exists. The code + # works for GAE and Cloud Run metadata server as well. + try: + from google.auth import compute_engine + + request_new = requests.Request() + credentials = compute_engine.IDTokenCredentials( + request_new, audience, use_metadata_identity_endpoint=True + ) + credentials.refresh(request_new) + + return credentials.token + + except (ImportError, exceptions.TransportError, exceptions.RefreshError): + pass + + # 2. Try to use service account credentials to get ID token. + + # Try to get credentials from the GOOGLE_APPLICATION_CREDENTIALS environment + # variable. + credentials_filename = os.environ.get(environment_vars.CREDENTIALS) + if not ( + credentials_filename + and os.path.exists(credentials_filename) + and os.path.isfile(credentials_filename) + ): + raise exceptions.DefaultCredentialsError( + "Neither metadata server or valid service account credentials are found." + ) + + try: + with open(credentials_filename, "r") as f: + info = json.load(f) + credentials_content = ( + (info.get("type") == "service_account") and info or None + ) + + from google.oauth2 import _service_account_async as service_account + + credentials = service_account.IDTokenCredentials.from_service_account_info( + credentials_content, target_audience=audience + ) + except ValueError as caught_exc: + new_exc = exceptions.DefaultCredentialsError( + "Neither metadata server or valid service account credentials are found.", + caught_exc, + ) + six.raise_from(new_exc, caught_exc) + + await credentials.refresh(request) + return credentials.token diff --git a/google/oauth2/_service_account_async.py b/google/oauth2/_service_account_async.py new file mode 100644 index 000000000..0a4e724a4 --- /dev/null +++ b/google/oauth2/_service_account_async.py @@ -0,0 +1,132 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Service Accounts: JSON Web Token (JWT) Profile for OAuth 2.0 + +NOTE: This file adds asynchronous refresh methods to both credentials +classes, and therefore async/await syntax is required when calling this +method when using service account credentials with asynchronous functionality. +Otherwise, all other methods are inherited from the regular service account +credentials file google.oauth2.service_account + +""" + +from google.auth import _credentials_async as credentials_async +from google.auth import _helpers +from google.oauth2 import _client_async +from google.oauth2 import service_account + + +class Credentials( + service_account.Credentials, credentials_async.Scoped, credentials_async.Credentials +): + """Service account credentials + + Usually, you'll create these credentials with one of the helper + constructors. To create credentials using a Google service account + private key JSON file:: + + credentials = _service_account_async.Credentials.from_service_account_file( + 'service-account.json') + + Or if you already have the service account file loaded:: + + service_account_info = json.load(open('service_account.json')) + credentials = _service_account_async.Credentials.from_service_account_info( + service_account_info) + + Both helper methods pass on arguments to the constructor, so you can + specify additional scopes and a subject if necessary:: + + credentials = _service_account_async.Credentials.from_service_account_file( + 'service-account.json', + scopes=['email'], + subject='user@example.com') + + The credentials are considered immutable. If you want to modify the scopes + or the subject used for delegation, use :meth:`with_scopes` or + :meth:`with_subject`:: + + scoped_credentials = credentials.with_scopes(['email']) + delegated_credentials = credentials.with_subject(subject) + + To add a quota project, use :meth:`with_quota_project`:: + + credentials = credentials.with_quota_project('myproject-123') + """ + + @_helpers.copy_docstring(credentials_async.Credentials) + async def refresh(self, request): + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = await _client_async.jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry + + +class IDTokenCredentials( + service_account.IDTokenCredentials, + credentials_async.Signing, + credentials_async.Credentials, +): + """Open ID Connect ID Token-based service account credentials. + + These credentials are largely similar to :class:`.Credentials`, but instead + of using an OAuth 2.0 Access Token as the bearer token, they use an Open + ID Connect ID Token as the bearer token. These credentials are useful when + communicating to services that require ID Tokens and can not accept access + tokens. + + Usually, you'll create these credentials with one of the helper + constructors. To create credentials using a Google service account + private key JSON file:: + + credentials = ( + _service_account_async.IDTokenCredentials.from_service_account_file( + 'service-account.json')) + + Or if you already have the service account file loaded:: + + service_account_info = json.load(open('service_account.json')) + credentials = ( + _service_account_async.IDTokenCredentials.from_service_account_info( + service_account_info)) + + Both helper methods pass on arguments to the constructor, so you can + specify additional scopes and a subject if necessary:: + + credentials = ( + _service_account_async.IDTokenCredentials.from_service_account_file( + 'service-account.json', + scopes=['email'], + subject='user@example.com')) +` + The credentials are considered immutable. If you want to modify the scopes + or the subject used for delegation, use :meth:`with_scopes` or + :meth:`with_subject`:: + + scoped_credentials = credentials.with_scopes(['email']) + delegated_credentials = credentials.with_subject(subject) + + """ + + @_helpers.copy_docstring(credentials_async.Credentials) + async def refresh(self, request): + assertion = self._make_authorization_grant_assertion() + access_token, expiry, _ = await _client_async.id_token_jwt_grant( + request, self._token_uri, assertion + ) + self.token = access_token + self.expiry = expiry diff --git a/noxfile.py b/noxfile.py index c39f27c47..d497f5305 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,8 +29,18 @@ "responses", "grpcio", ] + +ASYNC_DEPENDENCIES = ["pytest-asyncio", "aioresponses"] + BLACK_VERSION = "black==19.3b0" -BLACK_PATHS = ["google", "tests", "noxfile.py", "setup.py", "docs/conf.py"] +BLACK_PATHS = [ + "google", + "tests", + "tests_async", + "noxfile.py", + "setup.py", + "docs/conf.py", +] @nox.session(python="3.7") @@ -44,6 +54,7 @@ def lint(session): "--application-import-names=google,tests,system_tests", "google", "tests", + "tests_async", ) session.run( "python", "setup.py", "check", "--metadata", "--restructuredtext", "--strict" @@ -64,8 +75,23 @@ def blacken(session): session.run("black", *BLACK_PATHS) -@nox.session(python=["2.7", "3.5", "3.6", "3.7", "3.8"]) +@nox.session(python=["3.6", "3.7", "3.8"]) def unit(session): + session.install(*TEST_DEPENDENCIES) + session.install(*(ASYNC_DEPENDENCIES)) + session.install(".") + session.run( + "pytest", + "--cov=google.auth", + "--cov=google.oauth2", + "--cov=tests", + "tests", + "tests_async", + ) + + +@nox.session(python=["2.7", "3.5"]) +def unit_prev_versions(session): session.install(*TEST_DEPENDENCIES) session.install(".") session.run( @@ -76,14 +102,17 @@ def unit(session): @nox.session(python="3.7") def cover(session): session.install(*TEST_DEPENDENCIES) + session.install(*(ASYNC_DEPENDENCIES)) session.install(".") session.run( "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", + "--cov=tests_async", "--cov-report=", "tests", + "tests_async", ) session.run("coverage", "report", "--show-missing", "--fail-under=100") @@ -117,5 +146,10 @@ def pypy(session): session.install(*TEST_DEPENDENCIES) session.install(".") session.run( - "pytest", "--cov=google.auth", "--cov=google.oauth2", "--cov=tests", "tests" + "pytest", + "--cov=google.auth", + "--cov=google.oauth2", + "--cov=tests", + "tests", + "tests_async", ) diff --git a/setup.py b/setup.py index 1c3578f60..dd58f30f2 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ 'rsa>=3.1.4,<5; python_version >= "3.5"', "setuptools>=40.3.0", "six>=1.9.0", + 'aiohttp >= 3.6.2, < 4.0.0dev; python_version>="3.6"', ) diff --git a/system_tests/noxfile.py b/system_tests/noxfile.py index 14cd3db8e..a039228d9 100644 --- a/system_tests/noxfile.py +++ b/system_tests/noxfile.py @@ -1,4 +1,4 @@ -# Copyright 2016 Google LLC +# Copyright 2020 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,7 +29,6 @@ import nox import py.path - HERE = os.path.abspath(os.path.dirname(__file__)) LIBRARY_DIR = os.path.join(HERE, "..") DATA_DIR = os.path.join(HERE, "data") @@ -169,92 +168,79 @@ def configure_cloud_sdk(session, application_default_credentials, project=False) # Test sesssions -TEST_DEPENDENCIES = ["pytest", "requests"] -PYTHON_VERSIONS = ["2.7", "3.7"] - - -@nox.session(python=PYTHON_VERSIONS) -def service_account(session): - session.install(*TEST_DEPENDENCIES) - session.install(LIBRARY_DIR) - session.run("pytest", "test_service_account.py") - - -@nox.session(python=PYTHON_VERSIONS) -def oauth2_credentials(session): - session.install(*TEST_DEPENDENCIES) - session.install(LIBRARY_DIR) - session.run("pytest", "test_oauth2_credentials.py") +TEST_DEPENDENCIES_ASYNC = ["aiohttp", "pytest-asyncio", "nest-asyncio"] +TEST_DEPENDENCIES_SYNC = ["pytest", "requests"] +PYTHON_VERSIONS_ASYNC = ["3.7"] +PYTHON_VERSIONS_SYNC = ["2.7", "3.7"] -@nox.session(python=PYTHON_VERSIONS) -def impersonated_credentials(session): - session.install(*TEST_DEPENDENCIES) +@nox.session(python=PYTHON_VERSIONS_SYNC) +def service_account_sync(session): + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_impersonated_credentials.py") + session.run("pytest", "system_tests_sync/test_service_account.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def default_explicit_service_account(session): session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE session.env[EXPECT_PROJECT_ENV] = "1" - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_default.py", "test_id_token.py") + session.run("pytest", "system_tests_sync/test_default.py", "system_tests_sync/test_id_token.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def default_explicit_authorized_user(session): session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_default.py") + session.run("pytest", "system_tests_sync/test_default.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def default_explicit_authorized_user_explicit_project(session): session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE session.env[EXPLICIT_PROJECT_ENV] = "example-project" session.env[EXPECT_PROJECT_ENV] = "1" - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_default.py") + session.run("pytest", "system_tests_sync/test_default.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def default_cloud_sdk_service_account(session): configure_cloud_sdk(session, SERVICE_ACCOUNT_FILE) session.env[EXPECT_PROJECT_ENV] = "1" - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_default.py") + session.run("pytest", "system_tests_sync/test_default.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def default_cloud_sdk_authorized_user(session): configure_cloud_sdk(session, AUTHORIZED_USER_FILE) - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_default.py") + session.run("pytest", "system_tests_sync/test_default.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def default_cloud_sdk_authorized_user_configured_project(session): configure_cloud_sdk(session, AUTHORIZED_USER_FILE, project=True) session.env[EXPECT_PROJECT_ENV] = "1" - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) - session.run("pytest", "test_default.py") - + session.run("pytest", "system_tests_sync/test_default.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def compute_engine(session): - session.install(*TEST_DEPENDENCIES) + session.install(*TEST_DEPENDENCIES_SYNC) # unset Application Default Credentials so # credentials are detected from environment del session.virtualenv.env["GOOGLE_APPLICATION_CREDENTIALS"] session.install(LIBRARY_DIR) - session.run("pytest", "test_compute_engine.py") + session.run("pytest", "system_tests_sync/test_compute_engine.py") @nox.session(python=["2.7"]) @@ -283,8 +269,8 @@ def app_engine(session): application_url = GAE_APP_URL_TMPL.format(GAE_TEST_APP_SERVICE, project_id) # Vendor in the test application's dependencies - session.chdir(os.path.join(HERE, "app_engine_test_app")) - session.install(*TEST_DEPENDENCIES) + session.chdir(os.path.join(HERE, "../app_engine_test_app")) + session.install(*TEST_DEPENDENCIES_SYNC) session.install(LIBRARY_DIR) session.run( "pip", "install", "--target", "lib", "-r", "requirements.txt", silent=True @@ -296,20 +282,82 @@ def app_engine(session): # Run the tests session.env["TEST_APP_URL"] = application_url session.chdir(HERE) - session.run("pytest", "test_app_engine.py") + session.run("pytest", "system_tests_sync/test_app_engine.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def grpc(session): session.install(LIBRARY_DIR) - session.install(*TEST_DEPENDENCIES, "google-cloud-pubsub==1.0.0") + session.install(*TEST_DEPENDENCIES_SYNC, "google-cloud-pubsub==1.0.0") session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE - session.run("pytest", "test_grpc.py") + session.run("pytest", "system_tests_sync/test_grpc.py") -@nox.session(python=PYTHON_VERSIONS) +@nox.session(python=PYTHON_VERSIONS_SYNC) def mtls_http(session): session.install(LIBRARY_DIR) - session.install(*TEST_DEPENDENCIES, "pyopenssl") + session.install(*TEST_DEPENDENCIES_SYNC, "pyopenssl") + session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE + session.run("pytest", "system_tests_sync/test_mtls_http.py") + +#ASYNC SYSTEM TESTS + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def service_account_async(session): + session.install(*(TEST_DEPENDENCIES_SYNC+TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_service_account.py") + + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def default_explicit_service_account_async(session): session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE - session.run("pytest", "test_mtls_http.py") + session.env[EXPECT_PROJECT_ENV] = "1" + session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_default.py", + "system_tests_async/test_id_token.py") + + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def default_explicit_authorized_user_async(session): + session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE + session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_default.py") + + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def default_explicit_authorized_user_explicit_project_async(session): + session.env[EXPLICIT_CREDENTIALS_ENV] = AUTHORIZED_USER_FILE + session.env[EXPLICIT_PROJECT_ENV] = "example-project" + session.env[EXPECT_PROJECT_ENV] = "1" + session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_default.py") + + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def default_cloud_sdk_service_account_async(session): + configure_cloud_sdk(session, SERVICE_ACCOUNT_FILE) + session.env[EXPECT_PROJECT_ENV] = "1" + session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_default.py") + + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def default_cloud_sdk_authorized_user_async(session): + configure_cloud_sdk(session, AUTHORIZED_USER_FILE) + session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_default.py") + + +@nox.session(python=PYTHON_VERSIONS_ASYNC) +def default_cloud_sdk_authorized_user_configured_project_async(session): + configure_cloud_sdk(session, AUTHORIZED_USER_FILE, project=True) + session.env[EXPECT_PROJECT_ENV] = "1" + session.install(*(TEST_DEPENDENCIES_SYNC + TEST_DEPENDENCIES_ASYNC)) + session.install(LIBRARY_DIR) + session.run("pytest", "system_tests_async/test_default.py") diff --git a/system_tests/system_tests_async/conftest.py b/system_tests/system_tests_async/conftest.py new file mode 100644 index 000000000..ecff74c96 --- /dev/null +++ b/system_tests/system_tests_async/conftest.py @@ -0,0 +1,108 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from google.auth import _helpers +import google.auth.transport.requests +import google.auth.transport.urllib3 +import pytest +import requests +import urllib3 + +import aiohttp +from google.auth.transport import _aiohttp_requests as aiohttp_requests +from system_tests.system_tests_sync import conftest as sync_conftest + +ASYNC_REQUESTS_SESSION = aiohttp.ClientSession() + +ASYNC_REQUESTS_SESSION.verify = False +TOKEN_INFO_URL = "https://www.googleapis.com/oauth2/v3/tokeninfo" + + +@pytest.fixture +def service_account_file(): + """The full path to a valid service account key file.""" + yield sync_conftest.SERVICE_ACCOUNT_FILE + + +@pytest.fixture +def impersonated_service_account_file(): + """The full path to a valid service account key file.""" + yield sync_conftest.IMPERSONATED_SERVICE_ACCOUNT_FILE + + +@pytest.fixture +def authorized_user_file(): + """The full path to a valid authorized user file.""" + yield sync_conftest.AUTHORIZED_USER_FILE + +@pytest.fixture(params=["aiohttp"]) +async def http_request(request): + """A transport.request object.""" + yield aiohttp_requests.Request(ASYNC_REQUESTS_SESSION) + +@pytest.fixture +async def token_info(http_request): + """Returns a function that obtains OAuth2 token info.""" + + async def _token_info(access_token=None, id_token=None): + query_params = {} + + if access_token is not None: + query_params["access_token"] = access_token + elif id_token is not None: + query_params["id_token"] = id_token + else: + raise ValueError("No token specified.") + + url = _helpers.update_query(sync_conftest.TOKEN_INFO_URL, query_params) + + response = await http_request(url=url, method="GET") + data = await response.data.read() + + return json.loads(data.decode("utf-8")) + + yield _token_info + + +@pytest.fixture +async def verify_refresh(http_request): + """Returns a function that verifies that credentials can be refreshed.""" + + async def _verify_refresh(credentials): + if credentials.requires_scopes: + credentials = credentials.with_scopes(["email", "profile"]) + + await credentials.refresh(http_request) + + assert credentials.token + assert credentials.valid + + yield _verify_refresh + + +def verify_environment(): + """Checks to make sure that requisite data files are available.""" + if not os.path.isdir(sync_conftest.DATA_DIR): + raise EnvironmentError( + "In order to run system tests, test data must exist in " + "system_tests/data. See CONTRIBUTING.rst for details." + ) + + +def pytest_configure(config): + """Pytest hook that runs before Pytest collects any tests.""" + verify_environment() diff --git a/system_tests/system_tests_async/test_default.py b/system_tests/system_tests_async/test_default.py new file mode 100644 index 000000000..383cbff01 --- /dev/null +++ b/system_tests/system_tests_async/test_default.py @@ -0,0 +1,30 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pytest + +import google.auth + +EXPECT_PROJECT_ID = os.environ.get("EXPECT_PROJECT_ID") + +@pytest.mark.asyncio +async def test_application_default_credentials(verify_refresh): + credentials, project_id = google.auth.default_async() + #breakpoint() + + if EXPECT_PROJECT_ID is not None: + assert project_id is not None + + await verify_refresh(credentials) diff --git a/system_tests/system_tests_async/test_id_token.py b/system_tests/system_tests_async/test_id_token.py new file mode 100644 index 000000000..a21b137b6 --- /dev/null +++ b/system_tests/system_tests_async/test_id_token.py @@ -0,0 +1,25 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from google.auth import jwt +import google.oauth2._id_token_async + +@pytest.mark.asyncio +async def test_fetch_id_token(http_request): + audience = "https://pubsub.googleapis.com" + token = await google.oauth2._id_token_async.fetch_id_token(http_request, audience) + + _, payload, _, _ = jwt._unverified_decode(token) + assert payload["aud"] == audience diff --git a/system_tests/system_tests_async/test_service_account.py b/system_tests/system_tests_async/test_service_account.py new file mode 100644 index 000000000..c1c16ccd7 --- /dev/null +++ b/system_tests/system_tests_async/test_service_account.py @@ -0,0 +1,53 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from google.auth import _helpers +from google.auth import exceptions +from google.auth import iam +from google.oauth2 import _service_account_async + + +@pytest.fixture +def credentials(service_account_file): + yield _service_account_async.Credentials.from_service_account_file(service_account_file) + + +@pytest.mark.asyncio +async def test_refresh_no_scopes(http_request, credentials): + """ + We expect the http request to refresh credentials + without scopes provided to throw an error. + """ + with pytest.raises(exceptions.RefreshError): + await credentials.refresh(http_request) + +@pytest.mark.asyncio +async def test_refresh_success(http_request, credentials, token_info): + credentials = credentials.with_scopes(["email", "profile"]) + await credentials.refresh(http_request) + + assert credentials.token + + info = await token_info(credentials.token) + + assert info["email"] == credentials.service_account_email + info_scopes = _helpers.string_to_scopes(info["scope"]) + assert set(info_scopes) == set( + [ + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + ] + ) diff --git a/system_tests/.gitignore b/system_tests/system_tests_sync/.gitignore similarity index 100% rename from system_tests/.gitignore rename to system_tests/system_tests_sync/.gitignore diff --git a/system_tests/__init__.py b/system_tests/system_tests_sync/__init__.py similarity index 100% rename from system_tests/__init__.py rename to system_tests/system_tests_sync/__init__.py diff --git a/system_tests/app_engine_test_app/.gitignore b/system_tests/system_tests_sync/app_engine_test_app/.gitignore similarity index 100% rename from system_tests/app_engine_test_app/.gitignore rename to system_tests/system_tests_sync/app_engine_test_app/.gitignore diff --git a/system_tests/app_engine_test_app/app.yaml b/system_tests/system_tests_sync/app_engine_test_app/app.yaml similarity index 100% rename from system_tests/app_engine_test_app/app.yaml rename to system_tests/system_tests_sync/app_engine_test_app/app.yaml diff --git a/system_tests/app_engine_test_app/appengine_config.py b/system_tests/system_tests_sync/app_engine_test_app/appengine_config.py similarity index 100% rename from system_tests/app_engine_test_app/appengine_config.py rename to system_tests/system_tests_sync/app_engine_test_app/appengine_config.py diff --git a/system_tests/app_engine_test_app/main.py b/system_tests/system_tests_sync/app_engine_test_app/main.py similarity index 100% rename from system_tests/app_engine_test_app/main.py rename to system_tests/system_tests_sync/app_engine_test_app/main.py diff --git a/system_tests/app_engine_test_app/requirements.txt b/system_tests/system_tests_sync/app_engine_test_app/requirements.txt similarity index 100% rename from system_tests/app_engine_test_app/requirements.txt rename to system_tests/system_tests_sync/app_engine_test_app/requirements.txt diff --git a/system_tests/conftest.py b/system_tests/system_tests_sync/conftest.py similarity index 96% rename from system_tests/conftest.py rename to system_tests/system_tests_sync/conftest.py index 02de84664..37a6fd346 100644 --- a/system_tests/conftest.py +++ b/system_tests/system_tests_sync/conftest.py @@ -24,12 +24,11 @@ HERE = os.path.dirname(__file__) -DATA_DIR = os.path.join(HERE, "data") +DATA_DIR = os.path.join(HERE, "../data") IMPERSONATED_SERVICE_ACCOUNT_FILE = os.path.join( DATA_DIR, "impersonated_service_account.json" ) SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "service_account.json") -AUTHORIZED_USER_FILE = os.path.join(DATA_DIR, "authorized_user.json") URLLIB3_HTTP = urllib3.PoolManager(retries=False) REQUESTS_SESSION = requests.Session() REQUESTS_SESSION.verify = False diff --git a/system_tests/secrets.tar.enc b/system_tests/system_tests_sync/secrets.tar.enc similarity index 100% rename from system_tests/secrets.tar.enc rename to system_tests/system_tests_sync/secrets.tar.enc diff --git a/system_tests/test_app_engine.py b/system_tests/system_tests_sync/test_app_engine.py similarity index 100% rename from system_tests/test_app_engine.py rename to system_tests/system_tests_sync/test_app_engine.py diff --git a/system_tests/test_compute_engine.py b/system_tests/system_tests_sync/test_compute_engine.py similarity index 100% rename from system_tests/test_compute_engine.py rename to system_tests/system_tests_sync/test_compute_engine.py diff --git a/system_tests/test_default.py b/system_tests/system_tests_sync/test_default.py similarity index 100% rename from system_tests/test_default.py rename to system_tests/system_tests_sync/test_default.py diff --git a/system_tests/test_grpc.py b/system_tests/system_tests_sync/test_grpc.py similarity index 100% rename from system_tests/test_grpc.py rename to system_tests/system_tests_sync/test_grpc.py diff --git a/system_tests/test_id_token.py b/system_tests/system_tests_sync/test_id_token.py similarity index 100% rename from system_tests/test_id_token.py rename to system_tests/system_tests_sync/test_id_token.py diff --git a/system_tests/test_impersonated_credentials.py b/system_tests/system_tests_sync/test_impersonated_credentials.py similarity index 100% rename from system_tests/test_impersonated_credentials.py rename to system_tests/system_tests_sync/test_impersonated_credentials.py diff --git a/system_tests/test_mtls_http.py b/system_tests/system_tests_sync/test_mtls_http.py similarity index 100% rename from system_tests/test_mtls_http.py rename to system_tests/system_tests_sync/test_mtls_http.py diff --git a/system_tests/test_oauth2_credentials.py b/system_tests/system_tests_sync/test_oauth2_credentials.py similarity index 100% rename from system_tests/test_oauth2_credentials.py rename to system_tests/system_tests_sync/test_oauth2_credentials.py diff --git a/system_tests/test_service_account.py b/system_tests/system_tests_sync/test_service_account.py similarity index 100% rename from system_tests/test_service_account.py rename to system_tests/system_tests_sync/test_service_account.py diff --git a/tests_async/__init__.py b/tests_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_async/conftest.py b/tests_async/conftest.py new file mode 100644 index 000000000..b4e90f0e8 --- /dev/null +++ b/tests_async/conftest.py @@ -0,0 +1,51 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys + +import mock +import pytest + + +def pytest_configure(): + """Load public certificate and private key.""" + pytest.data_dir = os.path.join( + os.path.abspath(os.path.join(__file__, "../..")), "tests/data" + ) + + with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh: + pytest.private_key_bytes = fh.read() + + with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh: + pytest.public_cert_bytes = fh.read() + + +@pytest.fixture +def mock_non_existent_module(monkeypatch): + """Mocks a non-existing module in sys.modules. + + Additionally mocks any non-existing modules specified in the dotted path. + """ + + def _mock_non_existent_module(path): + parts = path.split(".") + partial = [] + for part in parts: + partial.append(part) + current_module = ".".join(partial) + if current_module not in sys.modules: + monkeypatch.setitem(sys.modules, current_module, mock.MagicMock()) + + return _mock_non_existent_module diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py new file mode 100644 index 000000000..458937ac1 --- /dev/null +++ b/tests_async/oauth2/test__client_async.py @@ -0,0 +1,297 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json + +import mock +import pytest +import six +from six.moves import http_client +from six.moves import urllib + +from google.auth import _helpers +from google.auth import _jwt_async as jwt +from google.auth import exceptions +from google.oauth2 import _client as sync_client +from google.oauth2 import _client_async as _client +from tests.oauth2 import test__client as test_client + + +def test__handle_error_response(): + response_data = json.dumps({"error": "help", "error_description": "I'm alive"}) + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(r"help: I\'m alive") + + +def test__handle_error_response_non_json(): + response_data = "Help, I'm alive" + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(r"Help, I\'m alive") + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +def test__parse_expiry(unused_utcnow): + result = _client._parse_expiry({"expires_in": 500}) + assert result == datetime.datetime.min + datetime.timedelta(seconds=500) + + +def test__parse_expiry_none(): + assert _client._parse_expiry({}) is None + + +def make_request(response_data, status=http_client.OK): + response = mock.AsyncMock(spec=["transport.Response"]) + response.status = status + data = json.dumps(response_data).encode("utf-8") + response.data = mock.AsyncMock(spec=["__call__", "read"]) + response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + response.content = mock.AsyncMock(spec=["__call__"], return_value=data) + request = mock.AsyncMock(spec=["transport.Request"]) + request.return_value = response + return request + + +@pytest.mark.asyncio +async def test__token_endpoint_request(): + + request = make_request({"test": "response"}) + + result = await _client._token_endpoint_request( + request, "http://example.com", {"test": "params"} + ) + + # Check request call + request.assert_called_with( + method="POST", + url="http://example.com", + headers={"content-type": "application/x-www-form-urlencoded"}, + body="test=params".encode("utf-8"), + ) + + # Check result + assert result == {"test": "response"} + + +@pytest.mark.asyncio +async def test__token_endpoint_request_error(): + request = make_request({}, status=http_client.BAD_REQUEST) + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request(request, "http://example.com", {}) + + +@pytest.mark.asyncio +async def test__token_endpoint_request_internal_failure_error(): + request = make_request( + {"error_description": "internal_failure"}, status=http_client.BAD_REQUEST + ) + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request( + request, "http://example.com", {"error_description": "internal_failure"} + ) + + request = make_request( + {"error": "internal_failure"}, status=http_client.BAD_REQUEST + ) + + with pytest.raises(exceptions.RefreshError): + await _client._token_endpoint_request( + request, "http://example.com", {"error": "internal_failure"} + ) + + +def verify_request_params(request, params): + request_body = request.call_args[1]["body"].decode("utf-8") + request_params = urllib.parse.parse_qs(request_body) + + for key, value in six.iteritems(params): + assert request_params[key][0] == value + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +@pytest.mark.asyncio +async def test_jwt_grant(utcnow): + request = make_request( + {"access_token": "token", "expires_in": 500, "extra": "data"} + ) + + token, expiry, extra_data = await _client.jwt_grant( + request, "http://example.com", "assertion_value" + ) + + # Check request call + verify_request_params( + request, + {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"}, + ) + + # Check result + assert token == "token" + assert expiry == utcnow() + datetime.timedelta(seconds=500) + assert extra_data["extra"] == "data" + + +@pytest.mark.asyncio +async def test_jwt_grant_no_access_token(): + request = make_request( + { + # No access token. + "expires_in": 500, + "extra": "data", + } + ) + + with pytest.raises(exceptions.RefreshError): + await _client.jwt_grant(request, "http://example.com", "assertion_value") + + +@pytest.mark.asyncio +async def test_id_token_jwt_grant(): + now = _helpers.utcnow() + id_token_expiry = _helpers.datetime_to_secs(now) + id_token = jwt.encode(test_client.SIGNER, {"exp": id_token_expiry}).decode("utf-8") + request = make_request({"id_token": id_token, "extra": "data"}) + + token, expiry, extra_data = await _client.id_token_jwt_grant( + request, "http://example.com", "assertion_value" + ) + + # Check request call + verify_request_params( + request, + {"grant_type": sync_client._JWT_GRANT_TYPE, "assertion": "assertion_value"}, + ) + + # Check result + assert token == id_token + # JWT does not store microseconds + now = now.replace(microsecond=0) + assert expiry == now + assert extra_data["extra"] == "data" + + +@pytest.mark.asyncio +async def test_id_token_jwt_grant_no_access_token(): + request = make_request( + { + # No access token. + "expires_in": 500, + "extra": "data", + } + ) + + with pytest.raises(exceptions.RefreshError): + await _client.id_token_jwt_grant( + request, "http://example.com", "assertion_value" + ) + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +@pytest.mark.asyncio +async def test_refresh_grant(unused_utcnow): + request = make_request( + { + "access_token": "token", + "refresh_token": "new_refresh_token", + "expires_in": 500, + "extra": "data", + } + ) + + token, refresh_token, expiry, extra_data = await _client.refresh_grant( + request, "http://example.com", "refresh_token", "client_id", "client_secret" + ) + + # Check request call + verify_request_params( + request, + { + "grant_type": sync_client._REFRESH_GRANT_TYPE, + "refresh_token": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + }, + ) + + # Check result + assert token == "token" + assert refresh_token == "new_refresh_token" + assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500) + assert extra_data["extra"] == "data" + + +@mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) +@pytest.mark.asyncio +async def test_refresh_grant_with_scopes(unused_utcnow): + request = make_request( + { + "access_token": "token", + "refresh_token": "new_refresh_token", + "expires_in": 500, + "extra": "data", + "scope": test_client.SCOPES_AS_STRING, + } + ) + + token, refresh_token, expiry, extra_data = await _client.refresh_grant( + request, + "http://example.com", + "refresh_token", + "client_id", + "client_secret", + test_client.SCOPES_AS_LIST, + ) + + # Check request call. + verify_request_params( + request, + { + "grant_type": sync_client._REFRESH_GRANT_TYPE, + "refresh_token": "refresh_token", + "client_id": "client_id", + "client_secret": "client_secret", + "scope": test_client.SCOPES_AS_STRING, + }, + ) + + # Check result. + assert token == "token" + assert refresh_token == "new_refresh_token" + assert expiry == datetime.datetime.min + datetime.timedelta(seconds=500) + assert extra_data["extra"] == "data" + + +@pytest.mark.asyncio +async def test_refresh_grant_no_access_token(): + request = make_request( + { + # No access token. + "refresh_token": "new_refresh_token", + "expires_in": 500, + "extra": "data", + } + ) + + with pytest.raises(exceptions.RefreshError): + await _client.refresh_grant( + request, "http://example.com", "refresh_token", "client_id", "client_secret" + ) diff --git a/tests_async/oauth2/test_credentials_async.py b/tests_async/oauth2/test_credentials_async.py new file mode 100644 index 000000000..5c883d614 --- /dev/null +++ b/tests_async/oauth2/test_credentials_async.py @@ -0,0 +1,478 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +import pickle +import sys + +import mock +import pytest + +from google.auth import _helpers +from google.auth import exceptions +from google.oauth2 import _credentials_async as _credentials_async +from google.oauth2 import credentials +from tests.oauth2 import test_credentials + + +class TestCredentials: + + TOKEN_URI = "https://example.com/oauth2/token" + REFRESH_TOKEN = "refresh_token" + CLIENT_ID = "client_id" + CLIENT_SECRET = "client_secret" + + @classmethod + def make_credentials(cls): + return _credentials_async.Credentials( + token=None, + refresh_token=cls.REFRESH_TOKEN, + token_uri=cls.TOKEN_URI, + client_id=cls.CLIENT_ID, + client_secret=cls.CLIENT_SECRET, + ) + + def test_default_state(self): + credentials = self.make_credentials() + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expired + # Scopes aren't required for these credentials + assert not credentials.requires_scopes + # Test properties + assert credentials.refresh_token == self.REFRESH_TOKEN + assert credentials.token_uri == self.TOKEN_URI + assert credentials.client_id == self.CLIENT_ID + assert credentials.client_secret == self.CLIENT_SECRET + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_refresh_success(self, unused_utcnow, refresh_grant): + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = {"id_token": mock.sentinel.id_token} + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = self.make_credentials() + + # Refresh credentials + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + None, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + + # Check that the credentials are valid (have a token and are not + # expired) + assert creds.valid + + @pytest.mark.asyncio + async def test_refresh_no_refresh_token(self): + request = mock.AsyncMock(spec=["transport.Request"]) + credentials_ = _credentials_async.Credentials(token=None, refresh_token=None) + + with pytest.raises(exceptions.RefreshError, match="necessary fields"): + await credentials_.refresh(request) + + request.assert_not_called() + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_credentials_with_scopes_requested_refresh_success( + self, unused_utcnow, refresh_grant + ): + scopes = ["email", "profile"] + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = {"id_token": mock.sentinel.id_token} + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = _credentials_async.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + scopes=scopes, + ) + + # Refresh credentials + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + scopes, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.has_scopes(scopes) + + # Check that the credentials are valid (have a token and are not + # expired.) + assert creds.valid + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_credentials_with_scopes_returned_refresh_success( + self, unused_utcnow, refresh_grant + ): + scopes = ["email", "profile"] + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = { + "id_token": mock.sentinel.id_token, + "scopes": " ".join(scopes), + } + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = _credentials_async.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + scopes=scopes, + ) + + # Refresh credentials + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + scopes, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.has_scopes(scopes) + + # Check that the credentials are valid (have a token and are not + # expired.) + assert creds.valid + + @mock.patch("google.oauth2._client_async.refresh_grant", autospec=True) + @mock.patch( + "google.auth._helpers.utcnow", + return_value=datetime.datetime.min + _helpers.CLOCK_SKEW, + ) + @pytest.mark.asyncio + async def test_credentials_with_scopes_refresh_failure_raises_refresh_error( + self, unused_utcnow, refresh_grant + ): + scopes = ["email", "profile"] + scopes_returned = ["email"] + token = "token" + expiry = _helpers.utcnow() + datetime.timedelta(seconds=500) + grant_response = { + "id_token": mock.sentinel.id_token, + "scopes": " ".join(scopes_returned), + } + refresh_grant.return_value = ( + # Access token + token, + # New refresh token + None, + # Expiry, + expiry, + # Extra data + grant_response, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + creds = _credentials_async.Credentials( + token=None, + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + scopes=scopes, + ) + + # Refresh credentials + with pytest.raises( + exceptions.RefreshError, match="Not all requested scopes were granted" + ): + await creds.refresh(request) + + # Check jwt grant call. + refresh_grant.assert_called_with( + request, + self.TOKEN_URI, + self.REFRESH_TOKEN, + self.CLIENT_ID, + self.CLIENT_SECRET, + scopes, + ) + + # Check that the credentials have the token and expiry + assert creds.token == token + assert creds.expiry == expiry + assert creds.id_token == mock.sentinel.id_token + assert creds.has_scopes(scopes) + + # Check that the credentials are valid (have a token and are not + # expired.) + assert creds.valid + + def test_apply_with_quota_project_id(self): + creds = _credentials_async.Credentials( + token="token", + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + quota_project_id="quota-project-123", + ) + + headers = {} + creds.apply(headers) + assert headers["x-goog-user-project"] == "quota-project-123" + + def test_apply_with_no_quota_project_id(self): + creds = _credentials_async.Credentials( + token="token", + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + ) + + headers = {} + creds.apply(headers) + assert "x-goog-user-project" not in headers + + def test_with_quota_project(self): + creds = _credentials_async.Credentials( + token="token", + refresh_token=self.REFRESH_TOKEN, + token_uri=self.TOKEN_URI, + client_id=self.CLIENT_ID, + client_secret=self.CLIENT_SECRET, + quota_project_id="quota-project-123", + ) + + new_creds = creds.with_quota_project("new-project-456") + assert new_creds.quota_project_id == "new-project-456" + headers = {} + creds.apply(headers) + assert "x-goog-user-project" in headers + + def test_from_authorized_user_info(self): + info = test_credentials.AUTH_USER_INFO.copy() + + creds = _credentials_async.Credentials.from_authorized_user_info(info) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes is None + + scopes = ["email", "profile"] + creds = _credentials_async.Credentials.from_authorized_user_info(info, scopes) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes == scopes + + def test_from_authorized_user_file(self): + info = test_credentials.AUTH_USER_INFO.copy() + + creds = _credentials_async.Credentials.from_authorized_user_file( + test_credentials.AUTH_USER_JSON_FILE + ) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes is None + + scopes = ["email", "profile"] + creds = _credentials_async.Credentials.from_authorized_user_file( + test_credentials.AUTH_USER_JSON_FILE, scopes + ) + assert creds.client_secret == info["client_secret"] + assert creds.client_id == info["client_id"] + assert creds.refresh_token == info["refresh_token"] + assert creds.token_uri == credentials._GOOGLE_OAUTH2_TOKEN_ENDPOINT + assert creds.scopes == scopes + + def test_to_json(self): + info = test_credentials.AUTH_USER_INFO.copy() + creds = _credentials_async.Credentials.from_authorized_user_info(info) + + # Test with no `strip` arg + json_output = creds.to_json() + json_asdict = json.loads(json_output) + assert json_asdict.get("token") == creds.token + assert json_asdict.get("refresh_token") == creds.refresh_token + assert json_asdict.get("token_uri") == creds.token_uri + assert json_asdict.get("client_id") == creds.client_id + assert json_asdict.get("scopes") == creds.scopes + assert json_asdict.get("client_secret") == creds.client_secret + + # Test with a `strip` arg + json_output = creds.to_json(strip=["client_secret"]) + json_asdict = json.loads(json_output) + assert json_asdict.get("token") == creds.token + assert json_asdict.get("refresh_token") == creds.refresh_token + assert json_asdict.get("token_uri") == creds.token_uri + assert json_asdict.get("client_id") == creds.client_id + assert json_asdict.get("scopes") == creds.scopes + assert json_asdict.get("client_secret") is None + + def test_pickle_and_unpickle(self): + creds = self.make_credentials() + unpickled = pickle.loads(pickle.dumps(creds)) + + # make sure attributes aren't lost during pickling + assert list(creds.__dict__).sort() == list(unpickled.__dict__).sort() + + for attr in list(creds.__dict__): + assert getattr(creds, attr) == getattr(unpickled, attr) + + def test_pickle_with_missing_attribute(self): + creds = self.make_credentials() + + # remove an optional attribute before pickling + # this mimics a pickle created with a previous class definition with + # fewer attributes + del creds.__dict__["_quota_project_id"] + + unpickled = pickle.loads(pickle.dumps(creds)) + + # Attribute should be initialized by `__setstate__` + assert unpickled.quota_project_id is None + + # pickles are not compatible across versions + @pytest.mark.skipif( + sys.version_info < (3, 5), + reason="pickle file can only be loaded with Python >= 3.5", + ) + def test_unpickle_old_credentials_pickle(self): + # make sure a credentials file pickled with an older + # library version (google-auth==1.5.1) can be unpickled + with open( + os.path.join(test_credentials.DATA_DIR, "old_oauth_credentials_py3.pickle"), + "rb", + ) as f: + credentials = pickle.load(f) + assert credentials.quota_project_id is None + + +class TestUserAccessTokenCredentials(object): + def test_instance(self): + cred = _credentials_async.UserAccessTokenCredentials() + assert cred._account is None + + cred = cred.with_account("account") + assert cred._account == "account" + + @mock.patch("google.auth._cloud_sdk.get_auth_access_token", autospec=True) + def test_refresh(self, get_auth_access_token): + get_auth_access_token.return_value = "access_token" + cred = _credentials_async.UserAccessTokenCredentials() + cred.refresh(None) + assert cred.token == "access_token" + + def test_with_quota_project(self): + cred = _credentials_async.UserAccessTokenCredentials() + quota_project_cred = cred.with_quota_project("project-foo") + + assert quota_project_cred._quota_project_id == "project-foo" + assert quota_project_cred._account == cred._account + + @mock.patch( + "google.oauth2._credentials_async.UserAccessTokenCredentials.apply", + autospec=True, + ) + @mock.patch( + "google.oauth2._credentials_async.UserAccessTokenCredentials.refresh", + autospec=True, + ) + def test_before_request(self, refresh, apply): + cred = _credentials_async.UserAccessTokenCredentials() + cred.before_request(mock.Mock(), "GET", "https://example.com", {}) + refresh.assert_called() + apply.assert_called() diff --git a/tests_async/oauth2/test_id_token.py b/tests_async/oauth2/test_id_token.py new file mode 100644 index 000000000..a46bd615e --- /dev/null +++ b/tests_async/oauth2/test_id_token.py @@ -0,0 +1,205 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import mock +import pytest + +from google.auth import environment_vars +from google.auth import exceptions +import google.auth.compute_engine._metadata +from google.oauth2 import _id_token_async as id_token +from google.oauth2 import id_token as sync_id_token +from tests.oauth2 import test_id_token + + +def make_request(status, data=None): + response = mock.AsyncMock(spec=["transport.Response"]) + response.status = status + + if data is not None: + response.data = mock.AsyncMock(spec=["__call__", "read"]) + response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + + request = mock.AsyncMock(spec=["transport.Request"]) + request.return_value = response + return request + + +@pytest.mark.asyncio +async def test__fetch_certs_success(): + certs = {"1": "cert"} + request = make_request(200, certs) + + returned_certs = await id_token._fetch_certs(request, mock.sentinel.cert_url) + + request.assert_called_once_with(mock.sentinel.cert_url, method="GET") + assert returned_certs == certs + + +@pytest.mark.asyncio +async def test__fetch_certs_failure(): + request = make_request(404) + + with pytest.raises(exceptions.TransportError): + await id_token._fetch_certs(request, mock.sentinel.cert_url) + + request.assert_called_once_with(mock.sentinel.cert_url, method="GET") + + +@mock.patch("google.auth.jwt.decode", autospec=True) +@mock.patch("google.oauth2._id_token_async._fetch_certs", autospec=True) +@pytest.mark.asyncio +async def test_verify_token(_fetch_certs, decode): + result = await id_token.verify_token(mock.sentinel.token, mock.sentinel.request) + + assert result == decode.return_value + _fetch_certs.assert_called_once_with( + mock.sentinel.request, sync_id_token._GOOGLE_OAUTH2_CERTS_URL + ) + decode.assert_called_once_with( + mock.sentinel.token, certs=_fetch_certs.return_value, audience=None + ) + + +@mock.patch("google.auth.jwt.decode", autospec=True) +@mock.patch("google.oauth2._id_token_async._fetch_certs", autospec=True) +@pytest.mark.asyncio +async def test_verify_token_args(_fetch_certs, decode): + result = await id_token.verify_token( + mock.sentinel.token, + mock.sentinel.request, + audience=mock.sentinel.audience, + certs_url=mock.sentinel.certs_url, + ) + + assert result == decode.return_value + _fetch_certs.assert_called_once_with(mock.sentinel.request, mock.sentinel.certs_url) + decode.assert_called_once_with( + mock.sentinel.token, + certs=_fetch_certs.return_value, + audience=mock.sentinel.audience, + ) + + +@mock.patch("google.oauth2._id_token_async.verify_token", autospec=True) +@pytest.mark.asyncio +async def test_verify_oauth2_token(verify_token): + verify_token.return_value = {"iss": "accounts.google.com"} + result = await id_token.verify_oauth2_token( + mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience + ) + + assert result == verify_token.return_value + verify_token.assert_called_once_with( + mock.sentinel.token, + mock.sentinel.request, + audience=mock.sentinel.audience, + certs_url=sync_id_token._GOOGLE_OAUTH2_CERTS_URL, + ) + + +@mock.patch("google.oauth2._id_token_async.verify_token", autospec=True) +@pytest.mark.asyncio +async def test_verify_oauth2_token_invalid_iss(verify_token): + verify_token.return_value = {"iss": "invalid_issuer"} + + with pytest.raises(exceptions.GoogleAuthError): + await id_token.verify_oauth2_token( + mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience + ) + + +@mock.patch("google.oauth2._id_token_async.verify_token", autospec=True) +@pytest.mark.asyncio +async def test_verify_firebase_token(verify_token): + result = await id_token.verify_firebase_token( + mock.sentinel.token, mock.sentinel.request, audience=mock.sentinel.audience + ) + + assert result == verify_token.return_value + verify_token.assert_called_once_with( + mock.sentinel.token, + mock.sentinel.request, + audience=mock.sentinel.audience, + certs_url=sync_id_token._GOOGLE_APIS_CERTS_URL, + ) + + +@pytest.mark.asyncio +async def test_fetch_id_token_from_metadata_server(): + def mock_init(self, request, audience, use_metadata_identity_endpoint): + assert use_metadata_identity_endpoint + self.token = "id_token" + + with mock.patch.multiple( + google.auth.compute_engine.IDTokenCredentials, + __init__=mock_init, + refresh=mock.Mock(), + ): + request = mock.AsyncMock() + token = await id_token.fetch_id_token(request, "https://pubsub.googleapis.com") + assert token == "id_token" + + +@mock.patch.object( + google.auth.compute_engine.IDTokenCredentials, + "__init__", + side_effect=exceptions.TransportError(), +) +@pytest.mark.asyncio +async def test_fetch_id_token_from_explicit_cred_json_file(mock_init, monkeypatch): + monkeypatch.setenv(environment_vars.CREDENTIALS, test_id_token.SERVICE_ACCOUNT_FILE) + + async def mock_refresh(self, request): + self.token = "id_token" + + with mock.patch.object( + google.oauth2._service_account_async.IDTokenCredentials, "refresh", mock_refresh + ): + request = mock.AsyncMock() + token = await id_token.fetch_id_token(request, "https://pubsub.googleapis.com") + assert token == "id_token" + + +@mock.patch.object( + google.auth.compute_engine.IDTokenCredentials, + "__init__", + side_effect=exceptions.TransportError(), +) +@pytest.mark.asyncio +async def test_fetch_id_token_no_cred_json_file(mock_init, monkeypatch): + monkeypatch.delenv(environment_vars.CREDENTIALS, raising=False) + + with pytest.raises(exceptions.DefaultCredentialsError): + request = mock.AsyncMock() + await id_token.fetch_id_token(request, "https://pubsub.googleapis.com") + + +@mock.patch.object( + google.auth.compute_engine.IDTokenCredentials, + "__init__", + side_effect=exceptions.TransportError(), +) +@pytest.mark.asyncio +async def test_fetch_id_token_invalid_cred_file(mock_init, monkeypatch): + not_json_file = os.path.join( + os.path.dirname(__file__), "../../tests/data/public_cert.pem" + ) + monkeypatch.setenv(environment_vars.CREDENTIALS, not_json_file) + + with pytest.raises(exceptions.DefaultCredentialsError): + request = mock.AsyncMock() + await id_token.fetch_id_token(request, "https://pubsub.googleapis.com") diff --git a/tests_async/oauth2/test_service_account_async.py b/tests_async/oauth2/test_service_account_async.py new file mode 100644 index 000000000..40794536c --- /dev/null +++ b/tests_async/oauth2/test_service_account_async.py @@ -0,0 +1,372 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import mock +import pytest + +from google.auth import _helpers +from google.auth import crypt +from google.auth import jwt +from google.auth import transport +from google.oauth2 import _service_account_async as service_account +from tests.oauth2 import test_service_account + + +class TestCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + TOKEN_URI = "https://example.com/oauth2/token" + + @classmethod + def make_credentials(cls): + return service_account.Credentials( + test_service_account.SIGNER, cls.SERVICE_ACCOUNT_EMAIL, cls.TOKEN_URI + ) + + def test_from_service_account_info(self): + credentials = service_account.Credentials.from_service_account_info( + test_service_account.SERVICE_ACCOUNT_INFO + ) + + assert ( + credentials._signer.key_id + == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"] + ) + assert ( + credentials.service_account_email + == test_service_account.SERVICE_ACCOUNT_INFO["client_email"] + ) + assert ( + credentials._token_uri + == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"] + ) + + def test_from_service_account_info_args(self): + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() + scopes = ["email", "profile"] + subject = "subject" + additional_claims = {"meta": "data"} + + credentials = service_account.Credentials.from_service_account_info( + info, scopes=scopes, subject=subject, additional_claims=additional_claims + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials.project_id == info["project_id"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + assert credentials._scopes == scopes + assert credentials._subject == subject + assert credentials._additional_claims == additional_claims + + def test_from_service_account_file(self): + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() + + credentials = service_account.Credentials.from_service_account_file( + test_service_account.SERVICE_ACCOUNT_JSON_FILE + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials.project_id == info["project_id"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + + def test_from_service_account_file_args(self): + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() + scopes = ["email", "profile"] + subject = "subject" + additional_claims = {"meta": "data"} + + credentials = service_account.Credentials.from_service_account_file( + test_service_account.SERVICE_ACCOUNT_JSON_FILE, + subject=subject, + scopes=scopes, + additional_claims=additional_claims, + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials.project_id == info["project_id"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + assert credentials._scopes == scopes + assert credentials._subject == subject + assert credentials._additional_claims == additional_claims + + def test_default_state(self): + credentials = self.make_credentials() + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expired + # Scopes haven't been specified yet + assert credentials.requires_scopes + + def test_sign_bytes(self): + credentials = self.make_credentials() + to_sign = b"123" + signature = credentials.sign_bytes(to_sign) + assert crypt.verify_signature( + to_sign, signature, test_service_account.PUBLIC_CERT_BYTES + ) + + def test_signer(self): + credentials = self.make_credentials() + assert isinstance(credentials.signer, crypt.Signer) + + def test_signer_email(self): + credentials = self.make_credentials() + assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL + + def test_create_scoped(self): + credentials = self.make_credentials() + scopes = ["email", "profile"] + credentials = credentials.with_scopes(scopes) + assert credentials._scopes == scopes + + def test_with_claims(self): + credentials = self.make_credentials() + new_credentials = credentials.with_claims({"meep": "moop"}) + assert new_credentials._additional_claims == {"meep": "moop"} + + def test_with_quota_project(self): + credentials = self.make_credentials() + new_credentials = credentials.with_quota_project("new-project-456") + assert new_credentials.quota_project_id == "new-project-456" + hdrs = {} + new_credentials.apply(hdrs, token="tok") + assert "x-goog-user-project" in hdrs + + def test__make_authorization_grant_assertion(self): + credentials = self.make_credentials() + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + assert payload["aud"] == self.TOKEN_URI + + def test__make_authorization_grant_assertion_scoped(self): + credentials = self.make_credentials() + scopes = ["email", "profile"] + credentials = credentials.with_scopes(scopes) + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) + assert payload["scope"] == "email profile" + + def test__make_authorization_grant_assertion_subject(self): + credentials = self.make_credentials() + subject = "user@example.com" + credentials = credentials.with_subject(subject) + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) + assert payload["sub"] == subject + + @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_refresh_success(self, jwt_grant): + credentials = self.make_credentials() + token = "token" + jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + {}, + ) + request = mock.create_autospec(transport.Request, instance=True) + + # Refresh credentials + await credentials.refresh(request) + + # Check jwt grant call. + assert jwt_grant.called + + called_request, token_uri, assertion = jwt_grant.call_args[0] + assert called_request == request + assert token_uri == credentials._token_uri + assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES) + # No further assertion done on the token, as there are separate tests + # for checking the authorization grant assertion. + + # Check that the credentials have the token. + assert credentials.token == token + + # Check that the credentials are valid (have a token and are not + # expired) + assert credentials.valid + + @mock.patch("google.oauth2._client_async.jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_before_request_refreshes(self, jwt_grant): + credentials = self.make_credentials() + token = "token" + jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + None, + ) + request = mock.create_autospec(transport.Request, instance=True) + + # Credentials should start as invalid + assert not credentials.valid + + # before_request should cause a refresh + await credentials.before_request(request, "GET", "http://example.com?a=1#3", {}) + + # The refresh endpoint should've been called. + assert jwt_grant.called + + # Credentials should now be valid. + assert credentials.valid + + +class TestIDTokenCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + TOKEN_URI = "https://example.com/oauth2/token" + TARGET_AUDIENCE = "https://example.com" + + @classmethod + def make_credentials(cls): + return service_account.IDTokenCredentials( + test_service_account.SIGNER, + cls.SERVICE_ACCOUNT_EMAIL, + cls.TOKEN_URI, + cls.TARGET_AUDIENCE, + ) + + def test_from_service_account_info(self): + credentials = service_account.IDTokenCredentials.from_service_account_info( + test_service_account.SERVICE_ACCOUNT_INFO, + target_audience=self.TARGET_AUDIENCE, + ) + + assert ( + credentials._signer.key_id + == test_service_account.SERVICE_ACCOUNT_INFO["private_key_id"] + ) + assert ( + credentials.service_account_email + == test_service_account.SERVICE_ACCOUNT_INFO["client_email"] + ) + assert ( + credentials._token_uri + == test_service_account.SERVICE_ACCOUNT_INFO["token_uri"] + ) + assert credentials._target_audience == self.TARGET_AUDIENCE + + def test_from_service_account_file(self): + info = test_service_account.SERVICE_ACCOUNT_INFO.copy() + + credentials = service_account.IDTokenCredentials.from_service_account_file( + test_service_account.SERVICE_ACCOUNT_JSON_FILE, + target_audience=self.TARGET_AUDIENCE, + ) + + assert credentials.service_account_email == info["client_email"] + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._token_uri == info["token_uri"] + assert credentials._target_audience == self.TARGET_AUDIENCE + + def test_default_state(self): + credentials = self.make_credentials() + assert not credentials.valid + # Expiration hasn't been set yet + assert not credentials.expired + + def test_sign_bytes(self): + credentials = self.make_credentials() + to_sign = b"123" + signature = credentials.sign_bytes(to_sign) + assert crypt.verify_signature( + to_sign, signature, test_service_account.PUBLIC_CERT_BYTES + ) + + def test_signer(self): + credentials = self.make_credentials() + assert isinstance(credentials.signer, crypt.Signer) + + def test_signer_email(self): + credentials = self.make_credentials() + assert credentials.signer_email == self.SERVICE_ACCOUNT_EMAIL + + def test_with_target_audience(self): + credentials = self.make_credentials() + new_credentials = credentials.with_target_audience("https://new.example.com") + assert new_credentials._target_audience == "https://new.example.com" + + def test_with_quota_project(self): + credentials = self.make_credentials() + new_credentials = credentials.with_quota_project("project-foo") + assert new_credentials._quota_project_id == "project-foo" + + def test__make_authorization_grant_assertion(self): + credentials = self.make_credentials() + token = credentials._make_authorization_grant_assertion() + payload = jwt.decode(token, test_service_account.PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + assert payload["aud"] == self.TOKEN_URI + assert payload["target_audience"] == self.TARGET_AUDIENCE + + @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_refresh_success(self, id_token_jwt_grant): + credentials = self.make_credentials() + token = "token" + id_token_jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + {}, + ) + + request = mock.AsyncMock(spec=["transport.Request"]) + + # Refresh credentials + await credentials.refresh(request) + + # Check jwt grant call. + assert id_token_jwt_grant.called + + called_request, token_uri, assertion = id_token_jwt_grant.call_args[0] + assert called_request == request + assert token_uri == credentials._token_uri + assert jwt.decode(assertion, test_service_account.PUBLIC_CERT_BYTES) + # No further assertion done on the token, as there are separate tests + # for checking the authorization grant assertion. + + # Check that the credentials have the token. + assert credentials.token == token + + # Check that the credentials are valid (have a token and are not + # expired) + assert credentials.valid + + @mock.patch("google.oauth2._client_async.id_token_jwt_grant", autospec=True) + @pytest.mark.asyncio + async def test_before_request_refreshes(self, id_token_jwt_grant): + credentials = self.make_credentials() + token = "token" + id_token_jwt_grant.return_value = ( + token, + _helpers.utcnow() + datetime.timedelta(seconds=500), + None, + ) + request = mock.AsyncMock(spec=["transport.Request"]) + + # Credentials should start as invalid + assert not credentials.valid + + # before_request should cause a refresh + await credentials.before_request(request, "GET", "http://example.com?a=1#3", {}) + + # The refresh endpoint should've been called. + assert id_token_jwt_grant.called + + # Credentials should now be valid. + assert credentials.valid diff --git a/tests_async/test__default_async.py b/tests_async/test__default_async.py new file mode 100644 index 000000000..bca396aee --- /dev/null +++ b/tests_async/test__default_async.py @@ -0,0 +1,468 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import mock +import pytest + +from google.auth import _credentials_async as credentials +from google.auth import _default_async as _default +from google.auth import app_engine +from google.auth import compute_engine +from google.auth import environment_vars +from google.auth import exceptions +from google.oauth2 import _service_account_async as service_account +import google.oauth2.credentials +from tests import test__default as test_default + +MOCK_CREDENTIALS = mock.Mock(spec=credentials.CredentialsWithQuotaProject) +MOCK_CREDENTIALS.with_quota_project.return_value = MOCK_CREDENTIALS + +LOAD_FILE_PATCH = mock.patch( + "google.auth._default_async.load_credentials_from_file", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) + + +def test_load_credentials_from_missing_file(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file("") + + assert excinfo.match(r"not found") + + +def test_load_credentials_from_file_invalid_json(tmpdir): + jsonfile = tmpdir.join("invalid.json") + jsonfile.write("{") + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(jsonfile)) + + assert excinfo.match(r"not a valid json file") + + +def test_load_credentials_from_file_invalid_type(tmpdir): + jsonfile = tmpdir.join("invalid.json") + jsonfile.write(json.dumps({"type": "not-a-real-type"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(jsonfile)) + + assert excinfo.match(r"does not have a valid type") + + +def test_load_credentials_from_file_authorized_user(): + credentials, project_id = _default.load_credentials_from_file( + test_default.AUTHORIZED_USER_FILE + ) + assert isinstance(credentials, google.oauth2._credentials_async.Credentials) + assert project_id is None + + +def test_load_credentials_from_file_no_type(tmpdir): + # use the client_secrets.json, which is valid json but not a + # loadable credentials type + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(test_default.CLIENT_SECRETS_FILE) + + assert excinfo.match(r"does not have a valid type") + assert excinfo.match(r"Type is None") + + +def test_load_credentials_from_file_authorized_user_bad_format(tmpdir): + filename = tmpdir.join("authorized_user_bad.json") + filename.write(json.dumps({"type": "authorized_user"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(filename)) + + assert excinfo.match(r"Failed to load authorized user") + assert excinfo.match(r"missing fields") + + +def test_load_credentials_from_file_authorized_user_cloud_sdk(): + with pytest.warns(UserWarning, match="Cloud SDK"): + credentials, project_id = _default.load_credentials_from_file( + test_default.AUTHORIZED_USER_CLOUD_SDK_FILE + ) + assert isinstance(credentials, google.oauth2._credentials_async.Credentials) + assert project_id is None + + # No warning if the json file has quota project id. + credentials, project_id = _default.load_credentials_from_file( + test_default.AUTHORIZED_USER_CLOUD_SDK_WITH_QUOTA_PROJECT_ID_FILE + ) + assert isinstance(credentials, google.oauth2._credentials_async.Credentials) + assert project_id is None + + +def test_load_credentials_from_file_authorized_user_cloud_sdk_with_scopes(): + with pytest.warns(UserWarning, match="Cloud SDK"): + credentials, project_id = _default.load_credentials_from_file( + test_default.AUTHORIZED_USER_CLOUD_SDK_FILE, + scopes=["https://www.google.com/calendar/feeds"], + ) + assert isinstance(credentials, google.oauth2._credentials_async.Credentials) + assert project_id is None + assert credentials.scopes == ["https://www.google.com/calendar/feeds"] + + +def test_load_credentials_from_file_authorized_user_cloud_sdk_with_quota_project(): + credentials, project_id = _default.load_credentials_from_file( + test_default.AUTHORIZED_USER_CLOUD_SDK_FILE, quota_project_id="project-foo" + ) + + assert isinstance(credentials, google.oauth2._credentials_async.Credentials) + assert project_id is None + assert credentials.quota_project_id == "project-foo" + + +def test_load_credentials_from_file_service_account(): + credentials, project_id = _default.load_credentials_from_file( + test_default.SERVICE_ACCOUNT_FILE + ) + assert isinstance(credentials, service_account.Credentials) + assert project_id == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"] + + +def test_load_credentials_from_file_service_account_with_scopes(): + credentials, project_id = _default.load_credentials_from_file( + test_default.SERVICE_ACCOUNT_FILE, + scopes=["https://www.google.com/calendar/feeds"], + ) + assert isinstance(credentials, service_account.Credentials) + assert project_id == test_default.SERVICE_ACCOUNT_FILE_DATA["project_id"] + assert credentials.scopes == ["https://www.google.com/calendar/feeds"] + + +def test_load_credentials_from_file_service_account_bad_format(tmpdir): + filename = tmpdir.join("serivce_account_bad.json") + filename.write(json.dumps({"type": "service_account"})) + + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default.load_credentials_from_file(str(filename)) + + assert excinfo.match(r"Failed to load service account") + assert excinfo.match(r"missing fields") + + +@mock.patch.dict(os.environ, {}, clear=True) +def test__get_explicit_environ_credentials_no_env(): + assert _default._get_explicit_environ_credentials() == (None, None) + + +@LOAD_FILE_PATCH +def test__get_explicit_environ_credentials(load, monkeypatch): + monkeypatch.setenv(environment_vars.CREDENTIALS, "filename") + + credentials, project_id = _default._get_explicit_environ_credentials() + + assert credentials is MOCK_CREDENTIALS + assert project_id is mock.sentinel.project_id + load.assert_called_with("filename") + + +@LOAD_FILE_PATCH +def test__get_explicit_environ_credentials_no_project_id(load, monkeypatch): + load.return_value = MOCK_CREDENTIALS, None + monkeypatch.setenv(environment_vars.CREDENTIALS, "filename") + + credentials, project_id = _default._get_explicit_environ_credentials() + + assert credentials is MOCK_CREDENTIALS + assert project_id is None + + +@LOAD_FILE_PATCH +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +def test__get_gcloud_sdk_credentials(get_adc_path, load): + get_adc_path.return_value = test_default.SERVICE_ACCOUNT_FILE + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials is MOCK_CREDENTIALS + assert project_id is mock.sentinel.project_id + load.assert_called_with(test_default.SERVICE_ACCOUNT_FILE) + + +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +def test__get_gcloud_sdk_credentials_non_existent(get_adc_path, tmpdir): + non_existent = tmpdir.join("non-existent") + get_adc_path.return_value = str(non_existent) + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials is None + assert project_id is None + + +@mock.patch( + "google.auth._cloud_sdk.get_project_id", + return_value=mock.sentinel.project_id, + autospec=True, +) +@mock.patch("os.path.isfile", return_value=True, autospec=True) +@LOAD_FILE_PATCH +def test__get_gcloud_sdk_credentials_project_id(load, unused_isfile, get_project_id): + # Don't return a project ID from load file, make the function check + # the Cloud SDK project. + load.return_value = MOCK_CREDENTIALS, None + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials == MOCK_CREDENTIALS + assert project_id == mock.sentinel.project_id + assert get_project_id.called + + +@mock.patch("google.auth._cloud_sdk.get_project_id", return_value=None, autospec=True) +@mock.patch("os.path.isfile", return_value=True) +@LOAD_FILE_PATCH +def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_project_id): + # Don't return a project ID from load file, make the function check + # the Cloud SDK project. + load.return_value = MOCK_CREDENTIALS, None + + credentials, project_id = _default._get_gcloud_sdk_credentials() + + assert credentials == MOCK_CREDENTIALS + assert project_id is None + assert get_project_id.called + + +class _AppIdentityModule(object): + """The interface of the App Idenity app engine module. + See https://cloud.google.com/appengine/docs/standard/python/refdocs\ + /google.appengine.api.app_identity.app_identity + """ + + def get_application_id(self): + raise NotImplementedError() + + +@pytest.fixture +def app_identity(monkeypatch): + """Mocks the app_identity module for google.auth.app_engine.""" + app_identity_module = mock.create_autospec(_AppIdentityModule, instance=True) + monkeypatch.setattr(app_engine, "app_identity", app_identity_module) + yield app_identity_module + + +def test__get_gae_credentials(app_identity): + app_identity.get_application_id.return_value = mock.sentinel.project + + credentials, project_id = _default._get_gae_credentials() + + assert isinstance(credentials, app_engine.Credentials) + assert project_id == mock.sentinel.project + + +def test__get_gae_credentials_no_app_engine(): + import sys + + with mock.patch.dict("sys.modules"): + sys.modules["google.auth.app_engine"] = None + credentials, project_id = _default._get_gae_credentials() + assert credentials is None + assert project_id is None + + +def test__get_gae_credentials_no_apis(): + assert _default._get_gae_credentials() == (None, None) + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=True, autospec=True +) +@mock.patch( + "google.auth.compute_engine._metadata.get_project_id", + return_value="example-project", + autospec=True, +) +def test__get_gce_credentials(unused_get, unused_ping): + credentials, project_id = _default._get_gce_credentials() + + assert isinstance(credentials, compute_engine.Credentials) + assert project_id == "example-project" + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=False, autospec=True +) +def test__get_gce_credentials_no_ping(unused_ping): + credentials, project_id = _default._get_gce_credentials() + + assert credentials is None + assert project_id is None + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=True, autospec=True +) +@mock.patch( + "google.auth.compute_engine._metadata.get_project_id", + side_effect=exceptions.TransportError(), + autospec=True, +) +def test__get_gce_credentials_no_project_id(unused_get, unused_ping): + credentials, project_id = _default._get_gce_credentials() + + assert isinstance(credentials, compute_engine.Credentials) + assert project_id is None + + +def test__get_gce_credentials_no_compute_engine(): + import sys + + with mock.patch.dict("sys.modules"): + sys.modules["google.auth.compute_engine"] = None + credentials, project_id = _default._get_gce_credentials() + assert credentials is None + assert project_id is None + + +@mock.patch( + "google.auth.compute_engine._metadata.ping", return_value=False, autospec=True +) +def test__get_gce_credentials_explicit_request(ping): + _default._get_gce_credentials(mock.sentinel.request) + ping.assert_called_with(request=mock.sentinel.request) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_early_out(unused_get): + assert _default.default_async() == (MOCK_CREDENTIALS, mock.sentinel.project_id) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_explict_project_id(unused_get, monkeypatch): + monkeypatch.setenv(environment_vars.PROJECT, "explicit-env") + assert _default.default_async() == (MOCK_CREDENTIALS, "explicit-env") + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_explict_legacy_project_id(unused_get, monkeypatch): + monkeypatch.setenv(environment_vars.LEGACY_PROJECT, "explicit-env") + assert _default.default_async() == (MOCK_CREDENTIALS, "explicit-env") + + +@mock.patch("logging.Logger.warning", autospec=True) +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gcloud_sdk_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gae_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gce_credentials", + return_value=(MOCK_CREDENTIALS, None), + autospec=True, +) +def test_default_without_project_id( + unused_gce, unused_gae, unused_sdk, unused_explicit, logger_warning +): + assert _default.default_async() == (MOCK_CREDENTIALS, None) + logger_warning.assert_called_with(mock.ANY, mock.ANY, mock.ANY) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(None, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gcloud_sdk_credentials", + return_value=(None, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gae_credentials", + return_value=(None, None), + autospec=True, +) +@mock.patch( + "google.auth._default_async._get_gce_credentials", + return_value=(None, None), + autospec=True, +) +def test_default_fail(unused_gce, unused_gae, unused_sdk, unused_explicit): + with pytest.raises(exceptions.DefaultCredentialsError): + assert _default.default_async() + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +@mock.patch( + "google.auth._credentials_async.with_scopes_if_required", + return_value=MOCK_CREDENTIALS, + autospec=True, +) +def test_default_scoped(with_scopes, unused_get): + scopes = ["one", "two"] + + credentials, project_id = _default.default_async(scopes=scopes) + + assert credentials == with_scopes.return_value + assert project_id == mock.sentinel.project_id + with_scopes.assert_called_once_with(MOCK_CREDENTIALS, scopes) + + +@mock.patch( + "google.auth._default_async._get_explicit_environ_credentials", + return_value=(MOCK_CREDENTIALS, mock.sentinel.project_id), + autospec=True, +) +def test_default_no_app_engine_compute_engine_module(unused_get): + """ + google.auth.compute_engine and google.auth.app_engine are both optional + to allow not including them when using this package. This verifies + that default fails gracefully if these modules are absent + """ + import sys + + with mock.patch.dict("sys.modules"): + sys.modules["google.auth.compute_engine"] = None + sys.modules["google.auth.app_engine"] = None + assert _default.default_async() == (MOCK_CREDENTIALS, mock.sentinel.project_id) diff --git a/tests_async/test_credentials_async.py b/tests_async/test_credentials_async.py new file mode 100644 index 000000000..0a4890825 --- /dev/null +++ b/tests_async/test_credentials_async.py @@ -0,0 +1,177 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pytest + +from google.auth import _credentials_async as credentials +from google.auth import _helpers + + +class CredentialsImpl(credentials.Credentials): + def refresh(self, request): + self.token = request + + def with_quota_project(self, quota_project_id): + raise NotImplementedError() + + +def test_credentials_constructor(): + credentials = CredentialsImpl() + assert not credentials.token + assert not credentials.expiry + assert not credentials.expired + assert not credentials.valid + + +def test_expired_and_valid(): + credentials = CredentialsImpl() + credentials.token = "token" + + assert credentials.valid + assert not credentials.expired + + # Set the expiration to one second more than now plus the clock skew + # accomodation. These credentials should be valid. + credentials.expiry = ( + datetime.datetime.utcnow() + _helpers.CLOCK_SKEW + datetime.timedelta(seconds=1) + ) + + assert credentials.valid + assert not credentials.expired + + # Set the credentials expiration to now. Because of the clock skew + # accomodation, these credentials should report as expired. + credentials.expiry = datetime.datetime.utcnow() + + assert not credentials.valid + assert credentials.expired + + +@pytest.mark.asyncio +async def test_before_request(): + credentials = CredentialsImpl() + request = "token" + headers = {} + + # First call should call refresh, setting the token. + await credentials.before_request(request, "http://example.com", "GET", headers) + assert credentials.valid + assert credentials.token == "token" + assert headers["authorization"] == "Bearer token" + + request = "token2" + headers = {} + + # Second call shouldn't call refresh. + credentials.before_request(request, "http://example.com", "GET", headers) + + assert credentials.valid + assert credentials.token == "token" + + +def test_anonymous_credentials_ctor(): + anon = credentials.AnonymousCredentials() + + assert anon.token is None + assert anon.expiry is None + assert not anon.expired + assert anon.valid + + +def test_anonymous_credentials_refresh(): + anon = credentials.AnonymousCredentials() + + request = object() + with pytest.raises(ValueError): + anon.refresh(request) + + +def test_anonymous_credentials_apply_default(): + anon = credentials.AnonymousCredentials() + headers = {} + anon.apply(headers) + assert headers == {} + with pytest.raises(ValueError): + anon.apply(headers, token="TOKEN") + + +def test_anonymous_credentials_before_request(): + anon = credentials.AnonymousCredentials() + request = object() + method = "GET" + url = "https://example.com/api/endpoint" + headers = {} + anon.before_request(request, method, url, headers) + assert headers == {} + + +class ReadOnlyScopedCredentialsImpl(credentials.ReadOnlyScoped, CredentialsImpl): + @property + def requires_scopes(self): + return super(ReadOnlyScopedCredentialsImpl, self).requires_scopes + + +def test_readonly_scoped_credentials_constructor(): + credentials = ReadOnlyScopedCredentialsImpl() + assert credentials._scopes is None + + +def test_readonly_scoped_credentials_scopes(): + credentials = ReadOnlyScopedCredentialsImpl() + credentials._scopes = ["one", "two"] + assert credentials.scopes == ["one", "two"] + assert credentials.has_scopes(["one"]) + assert credentials.has_scopes(["two"]) + assert credentials.has_scopes(["one", "two"]) + assert not credentials.has_scopes(["three"]) + + +def test_readonly_scoped_credentials_requires_scopes(): + credentials = ReadOnlyScopedCredentialsImpl() + assert not credentials.requires_scopes + + +class RequiresScopedCredentialsImpl(credentials.Scoped, CredentialsImpl): + def __init__(self, scopes=None): + super(RequiresScopedCredentialsImpl, self).__init__() + self._scopes = scopes + + @property + def requires_scopes(self): + return not self.scopes + + def with_scopes(self, scopes): + return RequiresScopedCredentialsImpl(scopes=scopes) + + +def test_create_scoped_if_required_scoped(): + unscoped_credentials = RequiresScopedCredentialsImpl() + scoped_credentials = credentials.with_scopes_if_required( + unscoped_credentials, ["one", "two"] + ) + + assert scoped_credentials is not unscoped_credentials + assert not scoped_credentials.requires_scopes + assert scoped_credentials.has_scopes(["one", "two"]) + + +def test_create_scoped_if_required_not_scopes(): + unscoped_credentials = CredentialsImpl() + scoped_credentials = credentials.with_scopes_if_required( + unscoped_credentials, ["one", "two"] + ) + + assert scoped_credentials is unscoped_credentials diff --git a/tests_async/test_jwt_async.py b/tests_async/test_jwt_async.py new file mode 100644 index 000000000..a35b837b7 --- /dev/null +++ b/tests_async/test_jwt_async.py @@ -0,0 +1,356 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json + +import mock +import pytest + +from google.auth import _jwt_async as jwt_async +from google.auth import crypt +from google.auth import exceptions +from tests import test_jwt + + +@pytest.fixture +def signer(): + return crypt.RSASigner.from_string(test_jwt.PRIVATE_KEY_BYTES, "1") + + +class TestCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + SUBJECT = "subject" + AUDIENCE = "audience" + ADDITIONAL_CLAIMS = {"meta": "data"} + credentials = None + + @pytest.fixture(autouse=True) + def credentials_fixture(self, signer): + self.credentials = jwt_async.Credentials( + signer, + self.SERVICE_ACCOUNT_EMAIL, + self.SERVICE_ACCOUNT_EMAIL, + self.AUDIENCE, + ) + + def test_from_service_account_info(self): + with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + info = json.load(fh) + + credentials = jwt_async.Credentials.from_service_account_info( + info, audience=self.AUDIENCE + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + assert credentials._audience == self.AUDIENCE + + def test_from_service_account_info_args(self): + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt_async.Credentials.from_service_account_info( + info, + subject=self.SUBJECT, + audience=self.AUDIENCE, + additional_claims=self.ADDITIONAL_CLAIMS, + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._audience == self.AUDIENCE + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_service_account_file(self): + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt_async.Credentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + assert credentials._audience == self.AUDIENCE + + def test_from_service_account_file_args(self): + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt_async.Credentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE, + subject=self.SUBJECT, + audience=self.AUDIENCE, + additional_claims=self.ADDITIONAL_CLAIMS, + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._audience == self.AUDIENCE + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_signing_credentials(self): + jwt_from_signing = self.credentials.from_signing_credentials( + self.credentials, audience=mock.sentinel.new_audience + ) + jwt_from_info = jwt_async.Credentials.from_service_account_info( + test_jwt.SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience + ) + + assert isinstance(jwt_from_signing, jwt_async.Credentials) + assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id + assert jwt_from_signing._issuer == jwt_from_info._issuer + assert jwt_from_signing._subject == jwt_from_info._subject + assert jwt_from_signing._audience == jwt_from_info._audience + + def test_default_state(self): + assert not self.credentials.valid + # Expiration hasn't been set yet + assert not self.credentials.expired + + def test_with_claims(self): + new_audience = "new_audience" + new_credentials = self.credentials.with_claims(audience=new_audience) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._audience == new_audience + assert new_credentials._additional_claims == self.credentials._additional_claims + assert new_credentials._quota_project_id == self.credentials._quota_project_id + + def test_with_quota_project(self): + quota_project_id = "project-foo" + + new_credentials = self.credentials.with_quota_project(quota_project_id) + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._audience == self.credentials._audience + assert new_credentials._additional_claims == self.credentials._additional_claims + assert new_credentials._quota_project_id == quota_project_id + + def test_sign_bytes(self): + to_sign = b"123" + signature = self.credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES) + + def test_signer(self): + assert isinstance(self.credentials.signer, crypt.RSASigner) + + def test_signer_email(self): + assert ( + self.credentials.signer_email + == test_jwt.SERVICE_ACCOUNT_INFO["client_email"] + ) + + def _verify_token(self, token): + payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + return payload + + def test_refresh(self): + self.credentials.refresh(None) + assert self.credentials.valid + assert not self.credentials.expired + + def test_expired(self): + assert not self.credentials.expired + + self.credentials.refresh(None) + assert not self.credentials.expired + + with mock.patch("google.auth._helpers.utcnow") as now: + one_day = datetime.timedelta(days=1) + now.return_value = self.credentials.expiry + one_day + assert self.credentials.expired + + @pytest.mark.asyncio + async def test_before_request(self): + headers = {} + + self.credentials.refresh(None) + await self.credentials.before_request( + None, "GET", "http://example.com?a=1#3", headers + ) + + header_value = headers["authorization"] + _, token = header_value.split(" ") + + # Since the audience is set, it should use the existing token. + assert token.encode("utf-8") == self.credentials.token + + payload = self._verify_token(token) + assert payload["aud"] == self.AUDIENCE + + @pytest.mark.asyncio + async def test_before_request_refreshes(self): + assert not self.credentials.valid + await self.credentials.before_request( + None, "GET", "http://example.com?a=1#3", {} + ) + assert self.credentials.valid + + +class TestOnDemandCredentials(object): + SERVICE_ACCOUNT_EMAIL = "service-account@example.com" + SUBJECT = "subject" + ADDITIONAL_CLAIMS = {"meta": "data"} + credentials = None + + @pytest.fixture(autouse=True) + def credentials_fixture(self, signer): + self.credentials = jwt_async.OnDemandCredentials( + signer, + self.SERVICE_ACCOUNT_EMAIL, + self.SERVICE_ACCOUNT_EMAIL, + max_cache_size=2, + ) + + def test_from_service_account_info(self): + with open(test_jwt.SERVICE_ACCOUNT_JSON_FILE, "r") as fh: + info = json.load(fh) + + credentials = jwt_async.OnDemandCredentials.from_service_account_info(info) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + + def test_from_service_account_info_args(self): + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt_async.OnDemandCredentials.from_service_account_info( + info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_service_account_file(self): + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt_async.OnDemandCredentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == info["client_email"] + + def test_from_service_account_file_args(self): + info = test_jwt.SERVICE_ACCOUNT_INFO.copy() + + credentials = jwt_async.OnDemandCredentials.from_service_account_file( + test_jwt.SERVICE_ACCOUNT_JSON_FILE, + subject=self.SUBJECT, + additional_claims=self.ADDITIONAL_CLAIMS, + ) + + assert credentials._signer.key_id == info["private_key_id"] + assert credentials._issuer == info["client_email"] + assert credentials._subject == self.SUBJECT + assert credentials._additional_claims == self.ADDITIONAL_CLAIMS + + def test_from_signing_credentials(self): + jwt_from_signing = self.credentials.from_signing_credentials(self.credentials) + jwt_from_info = jwt_async.OnDemandCredentials.from_service_account_info( + test_jwt.SERVICE_ACCOUNT_INFO + ) + + assert isinstance(jwt_from_signing, jwt_async.OnDemandCredentials) + assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id + assert jwt_from_signing._issuer == jwt_from_info._issuer + assert jwt_from_signing._subject == jwt_from_info._subject + + def test_default_state(self): + # Credentials are *always* valid. + assert self.credentials.valid + # Credentials *never* expire. + assert not self.credentials.expired + + def test_with_claims(self): + new_claims = {"meep": "moop"} + new_credentials = self.credentials.with_claims(additional_claims=new_claims) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._additional_claims == new_claims + + def test_with_quota_project(self): + quota_project_id = "project-foo" + new_credentials = self.credentials.with_quota_project(quota_project_id) + + assert new_credentials._signer == self.credentials._signer + assert new_credentials._issuer == self.credentials._issuer + assert new_credentials._subject == self.credentials._subject + assert new_credentials._additional_claims == self.credentials._additional_claims + assert new_credentials._quota_project_id == quota_project_id + + def test_sign_bytes(self): + to_sign = b"123" + signature = self.credentials.sign_bytes(to_sign) + assert crypt.verify_signature(to_sign, signature, test_jwt.PUBLIC_CERT_BYTES) + + def test_signer(self): + assert isinstance(self.credentials.signer, crypt.RSASigner) + + def test_signer_email(self): + assert ( + self.credentials.signer_email + == test_jwt.SERVICE_ACCOUNT_INFO["client_email"] + ) + + def _verify_token(self, token): + payload = jwt_async.decode(token, test_jwt.PUBLIC_CERT_BYTES) + assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL + return payload + + def test_refresh(self): + with pytest.raises(exceptions.RefreshError): + self.credentials.refresh(None) + + def test_before_request(self): + headers = {} + + self.credentials.before_request( + None, "GET", "http://example.com?a=1#3", headers + ) + + _, token = headers["authorization"].split(" ") + payload = self._verify_token(token) + + assert payload["aud"] == "http://example.com" + + # Making another request should re-use the same token. + self.credentials.before_request(None, "GET", "http://example.com?b=2", headers) + + _, new_token = headers["authorization"].split(" ") + + assert new_token == token + + def test_expired_token(self): + self.credentials._cache["audience"] = ( + mock.sentinel.token, + datetime.datetime.min, + ) + + token = self.credentials._get_jwt_for_audience("audience") + + assert token != mock.sentinel.token diff --git a/tests_async/transport/__init__.py b/tests_async/transport/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_async/transport/async_compliance.py b/tests_async/transport/async_compliance.py new file mode 100644 index 000000000..9c4b173c2 --- /dev/null +++ b/tests_async/transport/async_compliance.py @@ -0,0 +1,133 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import flask +import pytest +from pytest_localserver.http import WSGIServer +from six.moves import http_client + +from google.auth import exceptions +from tests.transport import compliance + + +class RequestResponseTests(object): + @pytest.fixture(scope="module") + def server(self): + """Provides a test HTTP server. + + The test server is automatically created before + a test and destroyed at the end. The server is serving a test + application that can be used to verify requests. + """ + app = flask.Flask(__name__) + app.debug = True + + # pylint: disable=unused-variable + # (pylint thinks the flask routes are unusued.) + @app.route("/basic") + def index(): + header_value = flask.request.headers.get("x-test-header", "value") + headers = {"X-Test-Header": header_value} + return "Basic Content", http_client.OK, headers + + @app.route("/server_error") + def server_error(): + return "Error", http_client.INTERNAL_SERVER_ERROR + + @app.route("/wait") + def wait(): + time.sleep(3) + return "Waited" + + # pylint: enable=unused-variable + + server = WSGIServer(application=app.wsgi_app) + server.start() + yield server + server.stop() + + @pytest.mark.asyncio + async def test_request_basic(self, server): + request = self.make_request() + response = await request(url=server.url + "/basic", method="GET") + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + # Use 13 as this is the length of the data written into the stream. + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_basic_with_http(self, server): + request = self.make_with_parameter_request() + response = await request(url=server.url + "/basic", method="GET") + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + # Use 13 as this is the length of the data written into the stream. + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_with_timeout_success(self, server): + request = self.make_request() + response = await request(url=server.url + "/basic", method="GET", timeout=2) + + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "value" + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_with_timeout_failure(self, server): + request = self.make_request() + + with pytest.raises(exceptions.TransportError): + await request(url=server.url + "/wait", method="GET", timeout=1) + + @pytest.mark.asyncio + async def test_request_headers(self, server): + request = self.make_request() + response = await request( + url=server.url + "/basic", + method="GET", + headers={"x-test-header": "hello world"}, + ) + + assert response.status == http_client.OK + assert response.headers["x-test-header"] == "hello world" + + data = await response.data.read(13) + assert data == b"Basic Content" + + @pytest.mark.asyncio + async def test_request_error(self, server): + request = self.make_request() + + response = await request(url=server.url + "/server_error", method="GET") + assert response.status == http_client.INTERNAL_SERVER_ERROR + data = await response.data.read(5) + assert data == b"Error" + + @pytest.mark.asyncio + async def test_connection_error(self): + request = self.make_request() + + with pytest.raises(exceptions.TransportError): + await request(url="http://{}".format(compliance.NXDOMAIN), method="GET") diff --git a/tests_async/transport/test_aiohttp_requests.py b/tests_async/transport/test_aiohttp_requests.py new file mode 100644 index 000000000..10c31db8f --- /dev/null +++ b/tests_async/transport/test_aiohttp_requests.py @@ -0,0 +1,245 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import aiohttp +from aioresponses import aioresponses, core +import mock +import pytest +from tests_async.transport import async_compliance + +import google.auth._credentials_async +from google.auth.transport import _aiohttp_requests as aiohttp_requests +import google.auth.transport._mtls_helper + + +class TestCombinedResponse: + @pytest.mark.asyncio + async def test__is_compressed(self): + response = core.CallbackResult(headers={"Content-Encoding": "gzip"}) + combined_response = aiohttp_requests._CombinedResponse(response) + compressed = combined_response._is_compressed() + assert compressed + + def test__is_compressed_not(self): + response = core.CallbackResult(headers={"Content-Encoding": "not"}) + combined_response = aiohttp_requests._CombinedResponse(response) + compressed = combined_response._is_compressed() + assert not compressed + + @pytest.mark.asyncio + async def test_raw_content(self): + + mock_response = mock.AsyncMock() + mock_response.content.read.return_value = mock.sentinel.read + combined_response = aiohttp_requests._CombinedResponse(response=mock_response) + raw_content = await combined_response.raw_content() + assert raw_content == mock.sentinel.read + + # Second call to validate the preconfigured path. + combined_response._raw_content = mock.sentinel.stored_raw + raw_content = await combined_response.raw_content() + assert raw_content == mock.sentinel.stored_raw + + @pytest.mark.asyncio + async def test_content(self): + mock_response = mock.AsyncMock() + mock_response.content.read.return_value = mock.sentinel.read + combined_response = aiohttp_requests._CombinedResponse(response=mock_response) + content = await combined_response.content() + assert content == mock.sentinel.read + + @mock.patch( + "google.auth.transport._aiohttp_requests.urllib3.response.MultiDecoder.decompress", + return_value="decompressed", + autospec=True, + ) + @pytest.mark.asyncio + async def test_content_compressed(self, urllib3_mock): + rm = core.RequestMatch( + "url", headers={"Content-Encoding": "gzip"}, payload="compressed" + ) + response = await rm.build_response(core.URL("url")) + + combined_response = aiohttp_requests._CombinedResponse(response=response) + content = await combined_response.content() + + urllib3_mock.assert_called_once() + assert content == "decompressed" + + +class TestResponse: + def test_ctor(self): + response = aiohttp_requests._Response(mock.sentinel.response) + assert response._response == mock.sentinel.response + + @pytest.mark.asyncio + async def test_headers_prop(self): + rm = core.RequestMatch("url", headers={"Content-Encoding": "header prop"}) + mock_response = await rm.build_response(core.URL("url")) + + response = aiohttp_requests._Response(mock_response) + assert response.headers["Content-Encoding"] == "header prop" + + @pytest.mark.asyncio + async def test_status_prop(self): + rm = core.RequestMatch("url", status=123) + mock_response = await rm.build_response(core.URL("url")) + response = aiohttp_requests._Response(mock_response) + assert response.status == 123 + + @pytest.mark.asyncio + async def test_data_prop(self): + mock_response = mock.AsyncMock() + mock_response.content.read.return_value = mock.sentinel.read + response = aiohttp_requests._Response(mock_response) + data = await response.data.read() + assert data == mock.sentinel.read + + +class TestRequestResponse(async_compliance.RequestResponseTests): + def make_request(self): + return aiohttp_requests.Request() + + def make_with_parameter_request(self): + http = mock.create_autospec(aiohttp.ClientSession, instance=True) + return aiohttp_requests.Request(http) + + def test_timeout(self): + http = mock.create_autospec(aiohttp.ClientSession, instance=True) + request = aiohttp_requests.Request(http) + request(url="http://example.com", method="GET", timeout=5) + + +class CredentialsStub(google.auth._credentials_async.Credentials): + def __init__(self, token="token"): + super(CredentialsStub, self).__init__() + self.token = token + + def apply(self, headers, token=None): + headers["authorization"] = self.token + + def refresh(self, request): + self.token += "1" + + +class TestAuthorizedSession(object): + TEST_URL = "http://example.com/" + method = "GET" + + def test_constructor(self): + authed_session = aiohttp_requests.AuthorizedSession(mock.sentinel.credentials) + assert authed_session.credentials == mock.sentinel.credentials + + def test_constructor_with_auth_request(self): + http = mock.create_autospec(aiohttp.ClientSession) + auth_request = aiohttp_requests.Request(http) + + authed_session = aiohttp_requests.AuthorizedSession( + mock.sentinel.credentials, auth_request=auth_request + ) + + assert authed_session._auth_request == auth_request + + @pytest.mark.asyncio + async def test_request(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + + mocked.get(self.TEST_URL, status=200, body="test") + session = aiohttp_requests.AuthorizedSession(credentials) + resp = await session.request( + "GET", + "http://example.com/", + headers={"Keep-Alive": "timeout=5, max=1000", "fake": b"bytes"}, + ) + + assert resp.status == 200 + assert "test" == await resp.text() + + await session.close() + + @pytest.mark.asyncio + async def test_ctx(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.get("http://test.example.com", payload=dict(foo="bar")) + session = aiohttp_requests.AuthorizedSession(credentials) + resp = await session.request("GET", "http://test.example.com") + data = await resp.json() + + assert dict(foo="bar") == data + + await session.close() + + @pytest.mark.asyncio + async def test_http_headers(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.post( + "http://example.com", + payload=dict(), + headers=dict(connection="keep-alive"), + ) + + session = aiohttp_requests.AuthorizedSession(credentials) + resp = await session.request("POST", "http://example.com") + + assert resp.headers["Connection"] == "keep-alive" + + await session.close() + + @pytest.mark.asyncio + async def test_regexp_example(self): + with aioresponses() as mocked: + credentials = mock.Mock(wraps=CredentialsStub()) + mocked.get("http://example.com", status=500) + mocked.get("http://example.com", status=200) + + session1 = aiohttp_requests.AuthorizedSession(credentials) + + resp1 = await session1.request("GET", "http://example.com") + session2 = aiohttp_requests.AuthorizedSession(credentials) + resp2 = await session2.request("GET", "http://example.com") + + assert resp1.status == 500 + assert resp2.status == 200 + + await session1.close() + await session2.close() + + @pytest.mark.asyncio + async def test_request_no_refresh(self): + credentials = mock.Mock(wraps=CredentialsStub()) + with aioresponses() as mocked: + mocked.get("http://example.com", status=200) + authed_session = aiohttp_requests.AuthorizedSession(credentials) + response = await authed_session.request("GET", "http://example.com") + assert response.status == 200 + assert credentials.before_request.called + assert not credentials.refresh.called + + await authed_session.close() + + @pytest.mark.asyncio + async def test_request_refresh(self): + credentials = mock.Mock(wraps=CredentialsStub()) + with aioresponses() as mocked: + mocked.get("http://example.com", status=401) + mocked.get("http://example.com", status=200) + authed_session = aiohttp_requests.AuthorizedSession(credentials) + response = await authed_session.request("GET", "http://example.com") + assert credentials.refresh.called + assert response.status == 200 + + await authed_session.close()