Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
feat: add mtls support (#367)
* feat: add mtls support

* update

* update

* update

* update

* update
  • Loading branch information
arithmetic1728 committed Feb 10, 2021
1 parent 1dc6d64 commit d35ab35
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 10 deletions.
21 changes: 18 additions & 3 deletions google/cloud/storage/_http.py
Expand Up @@ -15,27 +15,42 @@
"""Create / interact with Google Cloud Storage connections."""

import functools
import os
import pkg_resources

from google.cloud import _http

from google.cloud.storage import __version__


if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER
release = pkg_resources.get_distribution("google-cloud-core").parsed_version
if release < pkg_resources.parse_version("1.6.0"):
raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature")


class Connection(_http.JSONConnection):
"""A connection to Google Cloud Storage via the JSON REST API.
"""A connection to Google Cloud Storage via the JSON REST API. Mutual TLS feature will be
enabled if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true".
:type client: :class:`~google.cloud.storage.client.Client`
:param client: The client that owns the current connection.
:type client_info: :class:`~google.api_core.client_info.ClientInfo`
:param client_info: (Optional) instance used to generate user agent.
:type api_endpoint: str
:param api_endpoint: (Optional) api endpoint to use.
"""

DEFAULT_API_ENDPOINT = "https://storage.googleapis.com"
DEFAULT_API_MTLS_ENDPOINT = "https://storage.mtls.googleapis.com"

def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT):
def __init__(self, client, client_info=None, api_endpoint=None):
super(Connection, self).__init__(client, client_info)
self.API_BASE_URL = api_endpoint
self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT
self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT
self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None
self._client_info.client_library_version = __version__

# TODO: When metrics all use gccl, this should be removed #9552
Expand Down
30 changes: 25 additions & 5 deletions google/cloud/storage/blob.py
Expand Up @@ -830,9 +830,8 @@ def _get_download_url(
"""
name_value_pairs = []
if self.media_link is None:
base_url = _DOWNLOAD_URL_TEMPLATE.format(
hostname=client._connection.API_BASE_URL, path=self.path
)
hostname = _get_host_name(client._connection)
base_url = _DOWNLOAD_URL_TEMPLATE.format(hostname=hostname, path=self.path)
if self.generation is not None:
name_value_pairs.append(("generation", "{:d}".format(self.generation)))
else:
Expand Down Expand Up @@ -1685,8 +1684,9 @@ def _do_multipart_upload(
info = self._get_upload_arguments(content_type)
headers, object_metadata, content_type = info

hostname = _get_host_name(client._connection)
base_url = _MULTIPART_URL_TEMPLATE.format(
hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path
hostname=hostname, bucket_path=self.bucket.path
)
name_value_pairs = []

Expand Down Expand Up @@ -1866,8 +1866,9 @@ def _initiate_resumable_upload(
if extra_headers is not None:
headers.update(extra_headers)

hostname = _get_host_name(client._connection)
base_url = _RESUMABLE_URL_TEMPLATE.format(
hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path
hostname=hostname, bucket_path=self.bucket.path
)
name_value_pairs = []

Expand Down Expand Up @@ -3798,6 +3799,25 @@ def custom_time(self, value):
self._patch_property("customTime", value)


def _get_host_name(connection):
"""Returns the host name from the given connection.
:type connection: :class:`~google.cloud.storage._http.Connection`
:param connection: The connection object.
:rtype: str
:returns: The host name.
"""
# TODO: After google-cloud-core 1.6.0 is stable and we upgrade it
# to 1.6.0 in setup.py, we no longer need to check the attribute
# existence. We can simply return connection.get_api_base_url_for_mtls().
return (
connection.API_BASE_URL
if not hasattr(connection, "get_api_base_url_for_mtls")
else connection.get_api_base_url_for_mtls()
)


def _get_encryption_headers(key, source=False):
"""Builds customer encryption key headers
Expand Down
10 changes: 9 additions & 1 deletion google/cloud/storage/client.py
Expand Up @@ -32,6 +32,7 @@
from google.cloud.client import ClientWithProject
from google.cloud.exceptions import NotFound
from google.cloud.storage._helpers import _get_storage_host
from google.cloud.storage._helpers import _DEFAULT_STORAGE_HOST
from google.cloud.storage._helpers import _bucket_bound_hostname_url
from google.cloud.storage._http import Connection
from google.cloud.storage._signing import (
Expand Down Expand Up @@ -127,7 +128,14 @@ def __init__(

kw_args = {"client_info": client_info}

kw_args["api_endpoint"] = _get_storage_host()
# `api_endpoint` should be only set by the user via `client_options`,
# or if the _get_storage_host() returns a non-default value.
# `api_endpoint` plays an important role for mTLS, if it is not set,
# then mTLS logic will be applied to decide which endpoint will be used.
storage_host = _get_storage_host()
kw_args["api_endpoint"] = (
storage_host if storage_host != _DEFAULT_STORAGE_HOST else None
)

if client_options:
if type(client_options) == dict:
Expand Down
3 changes: 3 additions & 0 deletions noxfile.py
Expand Up @@ -107,6 +107,9 @@ def system(session):
# Sanity check: Only run tests if the environment variable is set.
if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""):
session.skip("Credentials must be set via environment variable")
# mTLS tests requires pyopenssl.
if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") == "true":
session.install("pyopenssl")

system_test_exists = os.path.exists(system_test_path)
system_test_folder_exists = os.path.exists(system_test_folder_path)
Expand Down
60 changes: 59 additions & 1 deletion tests/system/test_system.py
Expand Up @@ -81,6 +81,7 @@ class Config(object):

CLIENT = None
TEST_BUCKET = None
TESTING_MTLS = False


def setUpModule():
Expand All @@ -91,6 +92,10 @@ def setUpModule():
Config.TEST_BUCKET = Config.CLIENT.bucket(bucket_name)
Config.TEST_BUCKET.versioning_enabled = True
retry_429_503(Config.TEST_BUCKET.create)()
# mTLS testing uses the system test as well. For mTLS testing,
# GOOGLE_API_USE_CLIENT_CERTIFICATE env var will be set to "true"
# explicitly.
Config.TESTING_MTLS = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true"


def tearDownModule():
Expand All @@ -101,6 +106,15 @@ def tearDownModule():


class TestClient(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestClient, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
raise unittest.SkipTest("These tests require a service account credential")

def setUp(self):
self.case_hmac_keys_to_delete = []

Expand Down Expand Up @@ -563,6 +577,15 @@ def tearDown(self):
class TestStorageWriteFiles(TestStorageFiles):
ENCRYPTION_KEY = "b23ff11bba187db8c37077e6af3b25b8"

@classmethod
def setUpClass(cls):
super(TestStorageWriteFiles, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
raise unittest.SkipTest("These tests require a service account credential")

def test_large_file_write_from_stream(self):
blob = self.bucket.blob("LargeFile")

Expand Down Expand Up @@ -1285,11 +1308,14 @@ class TestStorageSignURLs(unittest.TestCase):

@classmethod
def setUpClass(cls):
super(TestStorageSignURLs, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
cls.skipTest("Signing tests requires a service account credential")
raise unittest.SkipTest(
"Signing tests requires a service account credential"
)

bucket_name = "gcp-signing" + unique_resource_id()
cls.bucket = retry_429_503(Config.CLIENT.create_bucket)(bucket_name)
Expand Down Expand Up @@ -1850,6 +1876,18 @@ class TestStorageNotificationCRUD(unittest.TestCase):
CUSTOM_ATTRIBUTES = {"attr1": "value1", "attr2": "value2"}
BLOB_NAME_PREFIX = "blob-name-prefix/"

@classmethod
def setUpClass(cls):
super(TestStorageNotificationCRUD, cls).setUpClass()
if Config.TESTING_MTLS:
# mTLS is only available for python-pubsub >= 2.2.0. However, the
# system test uses python-pubsub < 2.0, so we skip those tests.
# Note that python-pubsub >= 2.0 no longer supports python 2.7, so
# we can only upgrade it after python 2.7 system test is removed.
# Since python-pubsub >= 2.0 has a new set of api, the test code
# also needs to be updated.
raise unittest.SkipTest("Skip pubsub tests for mTLS testing")

@property
def topic_path(self):
return "projects/{}/topics/{}".format(Config.CLIENT.project, self.TOPIC_NAME)
Expand Down Expand Up @@ -2013,6 +2051,15 @@ def _kms_key_name(self, key_name=None):
@classmethod
def setUpClass(cls):
super(TestKMSIntegration, cls).setUpClass()
if Config.TESTING_MTLS:
# mTLS is only available for python-kms >= 2.2.0. However, the
# system test uses python-kms < 2.0, so we skip those tests.
# Note that python-kms >= 2.0 no longer supports python 2.7, so
# we can only upgrade it after python 2.7 system test is removed.
# Since python-kms >= 2.0 has a new set of api, the test code
# also needs to be updated.
raise unittest.SkipTest("Skip kms tests for mTLS testing")

_empty_bucket(Config.CLIENT, cls.bucket)

def setUp(self):
Expand Down Expand Up @@ -2466,6 +2513,17 @@ def test_ubla_set_unset_preserves_acls(self):


class TestV4POSTPolicies(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestV4POSTPolicies, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
# mTLS only works for user credentials, it doesn't work for
# service account credentials.
raise unittest.SkipTest("These tests require a service account credential")

def setUp(self):
self.case_buckets_to_delete = []

Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test__http.py
Expand Up @@ -25,6 +25,8 @@ def _get_target_class():
return Connection

def _make_one(self, *args, **kw):
if "api_endpoint" not in kw:
kw["api_endpoint"] = "https://storage.googleapis.com"
return self._get_target_class()(*args, **kw)

def test_extra_headers(self):
Expand Down Expand Up @@ -213,3 +215,16 @@ def test_api_request_conditional_retry_failed(self):
retry=conditional_retry_mock,
)
http.request.assert_called_once()

def test_mtls(self):
client = object()

conn = self._make_one(client, api_endpoint=None)
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True)
self.assertEqual(conn.API_BASE_URL, "https://storage.googleapis.com")
self.assertEqual(conn.API_BASE_MTLS_URL, "https://storage.mtls.googleapis.com")

conn = self._make_one(client, api_endpoint="http://foo")
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False)
self.assertEqual(conn.API_BASE_URL, "http://foo")
self.assertEqual(conn.API_BASE_MTLS_URL, "https://storage.mtls.googleapis.com")

0 comments on commit d35ab35

Please sign in to comment.