Skip to content

Commit

Permalink
feat: support self-signed JWT flow for service accounts (#774)
Browse files Browse the repository at this point in the history
See [RFC (internal only)](https://docs.google.com/document/d/1SNCVTmW6Rtr__u-_V7nsT9PhSzjj1z0P9fAD3YUgRoc/edit#) and https://aip.dev/auth/4111.

Support the self-signed JWT flow for service accounts by passing `default_scopes` and `default_host` in calls to the auth library and `create_channel`. This depends on features exposed in the following PRs: googleapis/python-api-core#134, googleapis/google-auth-library-python#665.

It may be easier to look at https://github.com/googleapis/python-translate/pull/107/files for a diff on a real library.

This change is written so that the library is (temporarily) compatible with older `google-api-core` and `google-auth` versions. Because of this it not possible to reach 100% coverage on a single unit test run. `pytest` runs twice in two of the `nox` sessions.

Miscellaneous changes:
- sprinkled in `__init__.py` files in subdirs of the `test/` directory, as otherwise pytest-cov seems to fail to collect coverage properly in some instances.
- new dependency on `packaging` for Version comparison https://pypi.org/project/packaging/

Co-authored-by: Brent Shaffer <betterbrent@google.com>
  • Loading branch information
busunkim96 and bshaffer committed Apr 21, 2021
1 parent 7ca9222 commit 89d6f35
Show file tree
Hide file tree
Showing 12 changed files with 458 additions and 62 deletions.
Expand Up @@ -2,10 +2,12 @@

{% block content %}
import abc
import typing
from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
import packaging.version
import pkg_resources

from google import auth # type: ignore
import google.api_core # type: ignore
from google.api_core import exceptions # type: ignore
from google.api_core import gapic_v1 # type: ignore
from google.api_core import retry as retries # type: ignore
Expand Down Expand Up @@ -34,6 +36,18 @@ try:
except pkg_resources.DistributionNotFound:
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()

try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class {{ service.name }}Transport(abc.ABC):
"""Abstract transport class for {{ service.name }}."""

Expand All @@ -43,13 +57,15 @@ class {{ service.name }}Transport(abc.ABC):
{%- endfor %}
)

DEFAULT_HOST: str = {% if service.host %}'{{ service.host }}'{% else %}{{ '' }}{% endif %}

def __init__(
self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
host: str = DEFAULT_HOST,
credentials: credentials.Credentials = None,
credentials_file: typing.Optional[str] = None,
scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
quota_project_id: typing.Optional[str] = None,
credentials_file: Optional[str] = None,
scopes: Optional[Sequence[str]] = None,
quota_project_id: Optional[str] = None,
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
**kwargs,
) -> None:
Expand All @@ -66,7 +82,7 @@ class {{ service.name }}Transport(abc.ABC):
credentials_file (Optional[str]): A file with credentials that can
be loaded with :func:`google.auth.load_credentials_from_file`.
This argument is mutually exclusive with credentials.
scope (Optional[Sequence[str]]): A list of scopes.
scopes (Optional[Sequence[str]]): A list of scopes.
quota_project_id (Optional[str]): An optional project to use for billing
and quota.
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
Expand All @@ -80,6 +96,8 @@ class {{ service.name }}Transport(abc.ABC):
host += ':443'
self._host = host

scopes_kwargs = self._get_scopes_kwargs(self._host, scopes)

# Save the scopes.
self._scopes = scopes or self.AUTH_SCOPES

Expand All @@ -91,17 +109,59 @@ class {{ service.name }}Transport(abc.ABC):
if credentials_file is not None:
credentials, _ = auth.load_credentials_from_file(
credentials_file,
scopes=self._scopes,
**scopes_kwargs,
quota_project_id=quota_project_id
)

elif credentials is None:
credentials, _ = auth.default(scopes=self._scopes, quota_project_id=quota_project_id)
credentials, _ = auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# Save the credentials.
self._credentials = credentials


# TODO(busunkim): These two class methods are in the base transport
# to avoid duplicating code across the transport classes. These functions
# should be deleted once the minimum required versions of google-api-core
# and google-auth are increased.

# TODO: Remove this function once google-auth >= 1.25.0 is required
@classmethod
def _get_scopes_kwargs(cls, host: str, scopes: Optional[Sequence[str]]) -> Dict[str, Optional[Sequence[str]]]:
"""Returns scopes kwargs to pass to google-auth methods depending on the google-auth version"""

scopes_kwargs = {}

if _GOOGLE_AUTH_VERSION and (
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES}

return scopes_kwargs

# TODO: Remove this function once google-api-core >= 1.26.0 is required
@classmethod
def _get_self_signed_jwt_kwargs(cls, host: str, scopes: Optional[Sequence[str]]) -> Dict[str, Union[Optional[Sequence[str]], str]]:
"""Returns kwargs to pass to grpc_helpers.create_channel depending on the google-api-core version"""

self_signed_jwt_kwargs: Dict[str, Union[Optional[Sequence[str]], str]] = {}

if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return self_signed_jwt_kwargs


def _prep_wrapped_messages(self, client_info):
# Precompute the wrapped methods.
self._wrapped_methods = {
Expand Down Expand Up @@ -138,11 +198,11 @@ class {{ service.name }}Transport(abc.ABC):
{%- for method in service.methods.values() %}

@property
def {{ method.name|snake_case }}(self) -> typing.Callable[
def {{ method.name|snake_case }}(self) -> Callable[
[{{ method.input.ident }}],
typing.Union[
Union[
{{ method.output.ident }},
typing.Awaitable[{{ method.output.ident }}]
Awaitable[{{ method.output.ident }}]
]]:
raise NotImplementedError()
{%- endfor %}
Expand All @@ -152,29 +212,29 @@ class {{ service.name }}Transport(abc.ABC):
@property
def set_iam_policy(
self,
) -> typing.Callable[
) -> Callable[
[iam_policy.SetIamPolicyRequest],
typing.Union[policy.Policy, typing.Awaitable[policy.Policy]],
Union[policy.Policy, Awaitable[policy.Policy]],
]:
raise NotImplementedError()

@property
def get_iam_policy(
self,
) -> typing.Callable[
) -> Callable[
[iam_policy.GetIamPolicyRequest],
typing.Union[policy.Policy, typing.Awaitable[policy.Policy]],
Union[policy.Policy, Awaitable[policy.Policy]],
]:
raise NotImplementedError()

@property
def test_iam_permissions(
self,
) -> typing.Callable[
) -> Callable[
[iam_policy.TestIamPermissionsRequest],
typing.Union[
Union[
iam_policy.TestIamPermissionsResponse,
typing.Awaitable[iam_policy.TestIamPermissionsResponse],
Awaitable[iam_policy.TestIamPermissionsResponse],
],
]:
raise NotImplementedError()
Expand Down
Expand Up @@ -2,7 +2,7 @@

{% block content %}
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
Expand Down Expand Up @@ -202,13 +202,15 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
scopes = scopes or cls.AUTH_SCOPES

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
**kwargs
)

Expand Down
Expand Up @@ -2,7 +2,7 @@

{% block content %}
import warnings
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union

from google.api_core import gapic_v1 # type: ignore
from google.api_core import grpc_helpers_async # type: ignore
Expand All @@ -12,6 +12,7 @@ from google.api_core import operations_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version

import grpc # type: ignore
from grpc.experimental import aio # type: ignore
Expand Down Expand Up @@ -75,13 +76,15 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
Returns:
aio.Channel: A gRPC AsyncIO channel object.
"""
scopes = scopes or cls.AUTH_SCOPES

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(host, scopes)

return grpc_helpers_async.create_channel(
host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes,
quota_project_id=quota_project_id,
**self_signed_jwt_kwargs,
**kwargs
)

Expand Down Expand Up @@ -163,7 +166,6 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
# If a channel was explicitly provided, set it.
self._grpc_channel = channel
self._ssl_channel_credentials = None

else:
if api_mtls_endpoint:
host = api_mtls_endpoint
Expand Down
Expand Up @@ -81,12 +81,14 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
"""
# Run the base constructor
# TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
# TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
# credentials object
super().__init__(
host=host,
credentials=credentials,
client_info=client_info,
)
self._session = AuthorizedSession(self._credentials)
self._session = AuthorizedSession(self._credentials, default_host=self.DEFAULT_HOST)
{%- if service.has_lro %}
self._operations_client = None
{%- endif %}
Expand All @@ -106,11 +108,14 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
# Sanity check: Only create a new client if we do not already have one.
if self._operations_client is None:
from google.api_core import grpc_helpers

self_signed_jwt_kwargs = cls._get_self_signed_jwt_kwargs(self._host, self._scopes)

self._operations_client = operations_v1.OperationsClient(
grpc_helpers.create_channel(
self._host,
credentials=self._credentials,
scopes=self.AUTH_SCOPES,
**self_signed_jwt_kwargs,
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
Expand Down
1 change: 0 additions & 1 deletion gapic/templates/.coveragerc.j2
Expand Up @@ -2,7 +2,6 @@
branch = True

[report]
fail_under = 100
show_missing = True
omit =
{{ api.naming.module_namespace|join("/") }}/{{ api.naming.module_name }}/__init__.py
Expand Down
62 changes: 62 additions & 0 deletions gapic/templates/noxfile.py.j2
Expand Up @@ -2,10 +2,28 @@

{% block content %}
import os
import pathlib
import shutil
import subprocess
import sys


import nox # type: ignore

CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()

LOWER_BOUND_CONSTRAINTS_FILE = CURRENT_DIRECTORY / "constraints.txt"
PACKAGE_NAME = subprocess.check_output([sys.executable, "setup.py", "--name"], encoding="utf-8")


nox.sessions = [
"unit",
"cover",
"mypy",
"check_lower_bounds"
# exclude update_lower_bounds from default
"docs",
]

@nox.session(python=['3.6', '3.7', '3.8', '3.9'])
def unit(session):
Expand All @@ -25,6 +43,18 @@ def unit(session):
)


@nox.session(python='3.7')
def cover(session):
"""Run the final coverage report.
This outputs the coverage report aggregating coverage from the unit
test runs (not system test runs), and then erases coverage data.
"""
session.install("coverage", "pytest-cov")
session.run("coverage", "report", "--show-missing", "--fail-under=100")

session.run("coverage", "erase")


@nox.session(python=['3.6', '3.7'])
def mypy(session):
"""Run the type checker."""
Expand All @@ -40,6 +70,38 @@ def mypy(session):
{%- endif %}
)


@nox.session
def update_lower_bounds(session):
"""Update lower bounds in constraints.txt to match setup.py"""
session.install('google-cloud-testutils')
session.install('.')

session.run(
'lower-bound-checker',
'update',
'--package-name',
PACKAGE_NAME,
'--constraints-file',
str(LOWER_BOUND_CONSTRAINTS_FILE),
)


@nox.session
def check_lower_bounds(session):
"""Check lower bounds in setup.py are reflected in constraints file"""
session.install('google-cloud-testutils')
session.install('.')

session.run(
'lower-bound-checker',
'check',
'--package-name',
PACKAGE_NAME,
'--constraints-file',
str(LOWER_BOUND_CONSTRAINTS_FILE),
)

@nox.session(python='3.6')
def docs(session):
"""Build the docs for this library."""
Expand Down
3 changes: 2 additions & 1 deletion gapic/templates/setup.py.j2
Expand Up @@ -29,8 +29,9 @@ setuptools.setup(
'google-api-core[grpc] >= 1.22.2, < 2.0.0dev',
'libcst >= 0.2.5',
'proto-plus >= 1.15.0',
'packaging >= 14.3',
{%- if api.requires_package(('google', 'iam', 'v1')) or opts.add_iam_methods %}
'grpc-google-iam-v1',
'grpc-google-iam-v1 >= 0.12.3, < 0.13dev',
{%- endif %}
),
python_requires='>=3.6',
Expand Down
2 changes: 2 additions & 0 deletions gapic/templates/tests/__init__.py.j2
@@ -0,0 +1,2 @@

{% extends '_base.py.j2' %}

0 comments on commit 89d6f35

Please sign in to comment.