Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mtls support #367

Merged
merged 9 commits into from Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 18 additions & 3 deletions google/cloud/storage/_http.py
Expand Up @@ -15,27 +15,42 @@
"""Create / interact with Google Cloud Storage connections."""

import functools
import os
import pkg_resources

from google.cloud import _http

from google.cloud.storage import __version__


if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
release = pkg_resources.get_distribution("google-cloud-core").parsed_version
if release < pkg_resources.parse_version("1.6.0"):
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature")


class Connection(_http.JSONConnection):
"""A connection to Google Cloud Storage via the JSON REST API.
"""A connection to Google Cloud Storage via the JSON REST API. Mutual TLS feature will be
enabled if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true".

:type client: :class:`~google.cloud.storage.client.Client`
:param client: The client that owns the current connection.

:type client_info: :class:`~google.api_core.client_info.ClientInfo`
:param client_info: (Optional) instance used to generate user agent.

:type api_endpoint: str
:param api_endpoint: (Optional) api endpoint to use.
"""

DEFAULT_API_ENDPOINT = "https://storage.googleapis.com"
DEFAULT_API_MTLS_ENDPOINT = "https://storage.mtls.googleapis.com"

def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT):
def __init__(self, client, client_info=None, api_endpoint=None):
super(Connection, self).__init__(client, client_info)
self.API_BASE_URL = api_endpoint
self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT
self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT
self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None
self._client_info.client_library_version = __version__

# TODO: When metrics all use gccl, this should be removed #9552
Expand Down
21 changes: 17 additions & 4 deletions google/cloud/storage/blob.py
Expand Up @@ -830,9 +830,12 @@ def _get_download_url(
"""
name_value_pairs = []
if self.media_link is None:
base_url = _DOWNLOAD_URL_TEMPLATE.format(
hostname=client._connection.API_BASE_URL, path=self.path
hostname = (
client._connection.API_BASE_URL
if not hasattr(client._connection, "get_api_base_url_for_mtls")
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
else client._connection.get_api_base_url_for_mtls()
)
base_url = _DOWNLOAD_URL_TEMPLATE.format(hostname=hostname, path=self.path)
if self.generation is not None:
name_value_pairs.append(("generation", "{:d}".format(self.generation)))
else:
Expand Down Expand Up @@ -1685,8 +1688,13 @@ def _do_multipart_upload(
info = self._get_upload_arguments(content_type)
headers, object_metadata, content_type = info

hostname = (
client._connection.API_BASE_URL
if not hasattr(client._connection, "get_api_base_url_for_mtls")
else client._connection.get_api_base_url_for_mtls()
)
base_url = _MULTIPART_URL_TEMPLATE.format(
hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path
hostname=hostname, bucket_path=self.bucket.path
)
name_value_pairs = []

Expand Down Expand Up @@ -1866,8 +1874,13 @@ def _initiate_resumable_upload(
if extra_headers is not None:
headers.update(extra_headers)

hostname = (
client._connection.API_BASE_URL
if not hasattr(client._connection, "get_api_base_url_for_mtls")
else client._connection.get_api_base_url_for_mtls()
)
base_url = _RESUMABLE_URL_TEMPLATE.format(
hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path
hostname=hostname, bucket_path=self.bucket.path
)
name_value_pairs = []

Expand Down
10 changes: 9 additions & 1 deletion google/cloud/storage/client.py
Expand Up @@ -32,6 +32,7 @@
from google.cloud.client import ClientWithProject
from google.cloud.exceptions import NotFound
from google.cloud.storage._helpers import _get_storage_host
from google.cloud.storage._helpers import _DEFAULT_STORAGE_HOST
from google.cloud.storage._helpers import _bucket_bound_hostname_url
from google.cloud.storage._http import Connection
from google.cloud.storage._signing import (
Expand Down Expand Up @@ -127,7 +128,14 @@ def __init__(

kw_args = {"client_info": client_info}

kw_args["api_endpoint"] = _get_storage_host()
# `api_endpoint` should be only set by the user via `client_options`,
# or if the _get_storage_host() returns a non-default value.
# `api_endpoint` plays an important role for mTLS, if it is not set,
# then mTLS logic will be applied to decide which endpoint will be used.
storage_host = _get_storage_host()
kw_args["api_endpoint"] = (
storage_host if storage_host != _DEFAULT_STORAGE_HOST else None
)

if client_options:
if type(client_options) == dict:
Expand Down
3 changes: 3 additions & 0 deletions noxfile.py
Expand Up @@ -107,6 +107,9 @@ def system(session):
# Sanity check: Only run tests if the environment variable is set.
if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""):
session.skip("Credentials must be set via environment variable")
# mTLS tests requires pyopenssl.
if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") == "true":
session.install("pyopenssl")

system_test_exists = os.path.exists(system_test_path)
system_test_folder_exists = os.path.exists(system_test_folder_path)
Expand Down
46 changes: 45 additions & 1 deletion tests/system/test_system.py
Expand Up @@ -81,6 +81,7 @@ class Config(object):

CLIENT = None
TEST_BUCKET = None
TESTING_MTLS = False


def setUpModule():
Expand All @@ -91,6 +92,10 @@ def setUpModule():
Config.TEST_BUCKET = Config.CLIENT.bucket(bucket_name)
Config.TEST_BUCKET.versioning_enabled = True
retry_429_503(Config.TEST_BUCKET.create)()
# mTLS testing uses the system test as well. For mTLS testing,
# GOOGLE_API_USE_CLIENT_CERTIFICATE env var will be set to "true"
# explicitly.
Config.TESTING_MTLS = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true"


def tearDownModule():
Expand All @@ -101,6 +106,15 @@ def tearDownModule():


class TestClient(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestClient, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
raise unittest.SkipTest("These tests require a service account credential")

def setUp(self):
self.case_hmac_keys_to_delete = []

Expand Down Expand Up @@ -563,6 +577,15 @@ def tearDown(self):
class TestStorageWriteFiles(TestStorageFiles):
ENCRYPTION_KEY = "b23ff11bba187db8c37077e6af3b25b8"

@classmethod
def setUpClass(cls):
super(TestStorageWriteFiles, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
raise unittest.SkipTest("These tests require a service account credential")

def test_large_file_write_from_stream(self):
blob = self.bucket.blob("LargeFile")

Expand Down Expand Up @@ -1285,11 +1308,14 @@ class TestStorageSignURLs(unittest.TestCase):

@classmethod
def setUpClass(cls):
super(TestStorageSignURLs, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
cls.skipTest("Signing tests requires a service account credential")
raise unittest.SkipTest(
"Signing tests requires a service account credential"
)

bucket_name = "gcp-signing" + unique_resource_id()
cls.bucket = retry_429_503(Config.CLIENT.create_bucket)(bucket_name)
Expand Down Expand Up @@ -1850,6 +1876,12 @@ class TestStorageNotificationCRUD(unittest.TestCase):
CUSTOM_ATTRIBUTES = {"attr1": "value1", "attr2": "value2"}
BLOB_NAME_PREFIX = "blob-name-prefix/"

@classmethod
def setUpClass(cls):
super(TestStorageNotificationCRUD, cls).setUpClass()
if Config.TESTING_MTLS:
raise unittest.SkipTest("Skip pubsub tests for mTLS testing")

@property
def topic_path(self):
return "projects/{}/topics/{}".format(Config.CLIENT.project, self.TOPIC_NAME)
Expand Down Expand Up @@ -2013,6 +2045,9 @@ def _kms_key_name(self, key_name=None):
@classmethod
def setUpClass(cls):
super(TestKMSIntegration, cls).setUpClass()
if Config.TESTING_MTLS:
raise unittest.SkipTest("Skip kms tests for mTLS testing")
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

_empty_bucket(Config.CLIENT, cls.bucket)

def setUp(self):
Expand Down Expand Up @@ -2466,6 +2501,15 @@ def test_ubla_set_unset_preserves_acls(self):


class TestV4POSTPolicies(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(TestV4POSTPolicies, cls).setUpClass()
if (
type(Config.CLIENT._credentials)
is not google.oauth2.service_account.Credentials
):
raise unittest.SkipTest("These tests require a service account credential")
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved

def setUp(self):
self.case_buckets_to_delete = []

Expand Down
15 changes: 15 additions & 0 deletions tests/unit/test__http.py
Expand Up @@ -25,6 +25,8 @@ def _get_target_class():
return Connection

def _make_one(self, *args, **kw):
if "api_endpoint" not in kw:
kw["api_endpoint"] = "https://storage.googleapis.com"
return self._get_target_class()(*args, **kw)

def test_extra_headers(self):
Expand Down Expand Up @@ -213,3 +215,16 @@ def test_api_request_conditional_retry_failed(self):
retry=conditional_retry_mock,
)
http.request.assert_called_once()

def test_mtls(self):
client = object()

conn = self._make_one(client, api_endpoint=None)
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True)
self.assertEqual(conn.API_BASE_URL, "https://storage.googleapis.com")
self.assertEqual(conn.API_BASE_MTLS_URL, "https://storage.mtls.googleapis.com")

conn = self._make_one(client, api_endpoint="http://foo")
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False)
self.assertEqual(conn.API_BASE_URL, "http://foo")
self.assertEqual(conn.API_BASE_MTLS_URL, "https://storage.mtls.googleapis.com")
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
57 changes: 57 additions & 0 deletions tests/unit/test_blob.py
Expand Up @@ -909,6 +909,24 @@ def test__get_download_url_on_the_fly(self):
)
self.assertEqual(download_url, expected_url)

def test__get_download_url_mtls(self):
blob_name = "bzzz-fly.txt"
bucket = _Bucket(name="buhkit")
blob = self._make_one(blob_name, bucket=bucket)

self.assertIsNone(blob.media_link)
client = mock.Mock(_connection=_Connection)
client._connection.API_BASE_URL = "https://storage.googleapis.com"
client._connection.get_api_base_url_for_mtls = mock.Mock(
return_value="https://foo.mtls"
)
download_url = blob._get_download_url(client)
del client._connection.get_api_base_url_for_mtls
expected_url = (
"https://foo.mtls/download/storage/v1/b/" "buhkit/o/bzzz-fly.txt?alt=media"
)
self.assertEqual(download_url, expected_url)

def test__get_download_url_on_the_fly_with_generation(self):
blob_name = "pretend.txt"
bucket = _Bucket(name="fictional")
Expand Down Expand Up @@ -1959,6 +1977,7 @@ def _do_multipart_success(
kms_key_name=None,
timeout=None,
metadata=None,
mtls=False,
):
from six.moves.urllib.parse import urlencode

Expand All @@ -1977,6 +1996,14 @@ def _do_multipart_success(

client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"])
client._connection.API_BASE_URL = "https://storage.googleapis.com"

# Mock get_api_base_url_for_mtls function.
mtls_url = "https://foo.mtls"
if mtls:
client._connection.get_api_base_url_for_mtls = mock.Mock(
return_value=mtls_url
)

data = b"data here hear hier"
stream = io.BytesIO(data)
content_type = u"application/xml"
Expand All @@ -2002,6 +2029,10 @@ def _do_multipart_success(
**timeout_kwarg
)

# Clean up the get_api_base_url_for_mtls mock.
if mtls:
del client._connection.get_api_base_url_for_mtls

# Check the mocks and the returned value.
self.assertIs(response, client._http.request.return_value)
if size is None:
Expand All @@ -2016,6 +2047,8 @@ def _do_multipart_success(
upload_url = (
"https://storage.googleapis.com/upload/storage/v1" + bucket.path + "/o"
)
if mtls:
upload_url = mtls_url + "/upload/storage/v1" + bucket.path + "/o"

qs_params = [("uploadType", "multipart")]

Expand Down Expand Up @@ -2064,6 +2097,12 @@ def _do_multipart_success(
def test__do_multipart_upload_no_size(self, mock_get_boundary):
self._do_multipart_success(mock_get_boundary, predefined_acl="private")

@mock.patch(u"google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_no_size_mtls(self, mock_get_boundary):
self._do_multipart_success(
mock_get_boundary, predefined_acl="private", mtls=True
)

@mock.patch(u"google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_with_size(self, mock_get_boundary):
self._do_multipart_success(mock_get_boundary, size=10)
Expand Down Expand Up @@ -2159,6 +2198,7 @@ def _initiate_resumable_helper(
kms_key_name=None,
timeout=None,
metadata=None,
mtls=False,
):
from six.moves.urllib.parse import urlencode
from google.resumable_media.requests import ResumableUpload
Expand Down Expand Up @@ -2197,6 +2237,14 @@ def _initiate_resumable_helper(
_http=transport, _connection=_Connection, spec=[u"_http"]
)
client._connection.API_BASE_URL = "https://storage.googleapis.com"

# Mock get_api_base_url_for_mtls function.
mtls_url = "https://foo.mtls"
if mtls:
client._connection.get_api_base_url_for_mtls = mock.Mock(
return_value=mtls_url
)

data = b"hello hallo halo hi-low"
stream = io.BytesIO(data)
content_type = u"text/plain"
Expand Down Expand Up @@ -2224,12 +2272,18 @@ def _initiate_resumable_helper(
**timeout_kwarg
)

# Clean up the get_api_base_url_for_mtls mock.
if mtls:
del client._connection.get_api_base_url_for_mtls

# Check the returned values.
self.assertIsInstance(upload, ResumableUpload)

upload_url = (
"https://storage.googleapis.com/upload/storage/v1" + bucket.path + "/o"
)
if mtls:
upload_url = mtls_url + "/upload/storage/v1" + bucket.path + "/o"
qs_params = [("uploadType", "resumable")]

if user_project is not None:
Expand Down Expand Up @@ -2322,6 +2376,9 @@ def test__initiate_resumable_upload_with_custom_timeout(self):
def test__initiate_resumable_upload_no_size(self):
self._initiate_resumable_helper()

def test__initiate_resumable_upload_no_size_mtls(self):
self._initiate_resumable_helper(mtls=True)

def test__initiate_resumable_upload_with_size(self):
self._initiate_resumable_helper(size=10000)

Expand Down