Skip to content

Commit

Permalink
feat: add mtls support (#75)
Browse files Browse the repository at this point in the history
* feat: add mtls support

* update

* update

* update

* chore: update

* update

Co-authored-by: Bu Sun Kim <8822365+busunkim96@users.noreply.github.com>
  • Loading branch information
arithmetic1728 and busunkim96 committed Feb 2, 2021
1 parent a0a4a28 commit a93129b
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 10 deletions.
47 changes: 46 additions & 1 deletion google/cloud/_http.py
Expand Up @@ -20,6 +20,7 @@
except ImportError:
import collections as collections_abc
import json
import os
import platform
import warnings

Expand Down Expand Up @@ -176,12 +177,56 @@ class JSONConnection(Connection):
API_BASE_URL = None
"""The base of the API call URL."""

API_BASE_MTLS_URL = None
"""The base of the API call URL for mutual TLS."""

ALLOW_AUTO_SWITCH_TO_MTLS_URL = False
"""Indicates if auto switch to mTLS url is allowed."""

API_VERSION = None
"""The version of the API, used in building the API call's URL."""

API_URL_TEMPLATE = None
"""A template for the URL of a particular API call."""

def get_api_base_url_for_mtls(self, api_base_url=None):
"""Return the api base url for mutual TLS.
Typically, you shouldn't need to use this method.
The logic is as follows:
If `api_base_url` is provided, just return this value; otherwise, the
return value depends `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable
value.
If the environment variable value is "always", return `API_BASE_MTLS_URL`.
If the environment variable value is "never", return `API_BASE_URL`.
Otherwise, if `ALLOW_AUTO_SWITCH_TO_MTLS_URL` is True and the underlying
http is mTLS, then return `API_BASE_MTLS_URL`; otherwise return `API_BASE_URL`.
:type api_base_url: str
:param api_base_url: User provided api base url. It takes precedence over
`API_BASE_URL` and `API_BASE_MTLS_URL`.
:rtype: str
:returns: The api base url used for mTLS.
"""
if api_base_url:
return api_base_url

env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
if env == "always":
url_to_use = self.API_BASE_MTLS_URL
elif env == "never":
url_to_use = self.API_BASE_URL
else:
if self.ALLOW_AUTO_SWITCH_TO_MTLS_URL:
url_to_use = self.API_BASE_MTLS_URL if self.http.is_mtls else self.API_BASE_URL
else:
url_to_use = self.API_BASE_URL
return url_to_use

def build_api_url(
self, path, query_params=None, api_base_url=None, api_version=None
):
Expand Down Expand Up @@ -210,7 +255,7 @@ def build_api_url(
:returns: The URL assembled from the pieces provided.
"""
url = self.API_URL_TEMPLATE.format(
api_base_url=(api_base_url or self.API_BASE_URL),
api_base_url=self.get_api_base_url_for_mtls(api_base_url),
api_version=(api_version or self.API_VERSION),
path=path,
)
Expand Down
2 changes: 2 additions & 0 deletions google/cloud/client.py
Expand Up @@ -159,6 +159,7 @@ def __init__(self, credentials=None, _http=None, client_options=None):
self._credentials = self._credentials.with_quota_project(client_options.quota_project_id)

self._http_internal = _http
self._client_cert_source = client_options.client_cert_source

def __getstate__(self):
"""Explicitly state that clients are not pickleable."""
Expand All @@ -183,6 +184,7 @@ def _http(self):
self._credentials,
refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT,
)
self._http_internal.configure_mtls_channel(self._client_cert_source)
return self._http_internal


Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -29,6 +29,7 @@
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
"google-api-core >= 1.21.0, < 2.0.0dev",
"google-auth >= 1.24.0, < 2.0dev",
# Support six==1.12.0 due to App Engine standard runtime.
# https://github.com/googleapis/python-cloud-core/issues/45
"six >=1.12.0",
Expand Down
46 changes: 46 additions & 0 deletions tests/unit/test__http.py
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import os
import unittest
import warnings

Expand Down Expand Up @@ -165,6 +166,7 @@ def _make_mock_one(self, *args, **kw):
class MockConnection(self._get_target_class()):
API_URL_TEMPLATE = "{api_base_url}/mock/{api_version}{path}"
API_BASE_URL = "http://mock"
API_BASE_MTLS_URL = "https://mock.mtls"
API_VERSION = "vMOCK"

return MockConnection(*args, **kw)
Expand Down Expand Up @@ -230,6 +232,50 @@ def test_build_api_url_w_extra_query_params_tuples(self):
self.assertEqual(parms["qux"], ["quux", "corge"])
self.assertEqual(parms["prettyPrint"], ["false"])

def test_get_api_base_url_for_mtls_w_api_base_url(self):
client = object()
conn = self._make_mock_one(client)
uri = conn.get_api_base_url_for_mtls(api_base_url="http://foo")
self.assertEqual(uri, "http://foo")

def test_get_api_base_url_for_mtls_env_always(self):
client = object()
conn = self._make_mock_one(client)
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "https://mock.mtls")

def test_get_api_base_url_for_mtls_env_never(self):
client = object()
conn = self._make_mock_one(client)
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "http://mock")

def test_get_api_base_url_for_mtls_env_auto(self):
client = mock.Mock()
client._http = mock.Mock()
client._http.is_mtls = False
conn = self._make_mock_one(client)

# ALLOW_AUTO_SWITCH_TO_MTLS_URL is False, so use regular endpoint.
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "http://mock")

# ALLOW_AUTO_SWITCH_TO_MTLS_URL is True, so now endpoint dependes
# on client._http.is_mtls
conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL = True

with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "http://mock")

client._http.is_mtls = True
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}):
uri = conn.get_api_base_url_for_mtls()
self.assertEqual(uri, "https://mock.mtls")

def test__make_request_no_data_no_content_type_no_headers(self):
from google.cloud._http import CLIENT_INFO_HEADER

Expand Down
20 changes: 11 additions & 9 deletions tests/unit/test_client.py
Expand Up @@ -125,20 +125,22 @@ def test_ctor__http_property_new(self):
from google.cloud.client import _CREDENTIALS_REFRESH_TIMEOUT

credentials = _make_credentials()
client = self._make_one(credentials=credentials)
mock_client_cert_source = mock.Mock()
client_options = {'client_cert_source': mock_client_cert_source}
client = self._make_one(credentials=credentials, client_options=client_options)
self.assertIsNone(client._http_internal)

authorized_session_patch = mock.patch(
"google.auth.transport.requests.AuthorizedSession",
return_value=mock.sentinel.http,
)
with authorized_session_patch as AuthorizedSession:
self.assertIs(client._http, mock.sentinel.http)
with mock.patch('google.auth.transport.requests.AuthorizedSession') as AuthorizedSession:
session = mock.Mock()
session.configure_mtls_channel = mock.Mock()
AuthorizedSession.return_value = session
self.assertIs(client._http, session)
# Check the mock.
AuthorizedSession.assert_called_once_with(credentials, refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT)
session.configure_mtls_channel.assert_called_once_with(mock_client_cert_source)
# Make sure the cached value is used on subsequent access.
self.assertIs(client._http_internal, mock.sentinel.http)
self.assertIs(client._http, mock.sentinel.http)
self.assertIs(client._http_internal, session)
self.assertIs(client._http, session)
self.assertEqual(AuthorizedSession.call_count, 1)

def test_from_service_account_json(self):
Expand Down

0 comments on commit a93129b

Please sign in to comment.