Skip to content

Commit

Permalink
fix: make gRPC auth plugin non-blocking + add default timeout value f…
Browse files Browse the repository at this point in the history
…or requests transport (#390)

This commit includes the following changes:
- `transport.grpc.AuthMetadataPlugin` is now non-blocking as gRPC requires
- `transport.requests.Request` now has a default timeout value of 120 seconds so that token refreshing will not be stuck

Resolves: #351
  • Loading branch information
chenyumic committed Nov 25, 2019
1 parent 3a46178 commit 0c33e9c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
16 changes: 15 additions & 1 deletion google/auth/transport/grpc.py
Expand Up @@ -16,6 +16,8 @@

from __future__ import absolute_import

from concurrent import futures

import six

try:
Expand Down Expand Up @@ -51,6 +53,7 @@ def __init__(self, credentials, request):
super(AuthMetadataPlugin, self).__init__()
self._credentials = credentials
self._request = request
self._pool = futures.ThreadPoolExecutor(max_workers=1)

def _get_authorization_headers(self, context):
"""Gets the authorization headers for a request.
Expand All @@ -66,6 +69,13 @@ def _get_authorization_headers(self, context):

return list(six.iteritems(headers))

@staticmethod
def _callback_wrapper(callback):
def wrapped(future):
callback(future.result(), None)

return wrapped

def __call__(self, context, callback):
"""Passes authorization metadata into the given callback.
Expand All @@ -74,7 +84,11 @@ def __call__(self, context, callback):
callback (grpc.AuthMetadataPluginCallback): The callback that will
be invoked to pass in the authorization metadata.
"""
callback(self._get_authorization_headers(context), None)
future = self._pool.submit(self._get_authorization_headers, context)
future.add_done_callback(self._callback_wrapper(callback))

def __del__(self):
self._pool.shutdown(wait=False)


def secure_authorized_channel(
Expand Down
2 changes: 1 addition & 1 deletion google/auth/transport/requests.py
Expand Up @@ -95,7 +95,7 @@ def __init__(self, session=None):
self.session = session

def __call__(
self, url, method="GET", body=None, headers=None, timeout=None, **kwargs
self, url, method="GET", body=None, headers=None, timeout=120, **kwargs
):
"""Make an HTTP request using requests.
Expand Down
5 changes: 5 additions & 0 deletions tests/transport/test_grpc.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import datetime
import time

import mock
import pytest
Expand Down Expand Up @@ -58,6 +59,8 @@ def test_call_no_refresh(self):

plugin(context, callback)

time.sleep(2)

callback.assert_called_once_with(
[(u"authorization", u"Bearer {}".format(credentials.token))], None
)
Expand All @@ -76,6 +79,8 @@ def test_call_refresh(self):

plugin(context, callback)

time.sleep(2)

assert credentials.token == "token1"
callback.assert_called_once_with(
[(u"authorization", u"Bearer {}".format(credentials.token))], None
Expand Down

0 comments on commit 0c33e9c

Please sign in to comment.