diff --git a/google/cloud/storage/_http.py b/google/cloud/storage/_http.py index 6e175196c..0dcc68cdb 100644 --- a/google/cloud/storage/_http.py +++ b/google/cloud/storage/_http.py @@ -15,27 +15,42 @@ """Create / interact with Google Cloud Storage connections.""" import functools +import os +import pkg_resources from google.cloud import _http from google.cloud.storage import __version__ +if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER + release = pkg_resources.get_distribution("google-cloud-core").parsed_version + if release < pkg_resources.parse_version("1.6.0"): + raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature") + + class Connection(_http.JSONConnection): - """A connection to Google Cloud Storage via the JSON REST API. + """A connection to Google Cloud Storage via the JSON REST API. Mutual TLS feature will be + enabled if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is set to "true". :type client: :class:`~google.cloud.storage.client.Client` :param client: The client that owns the current connection. :type client_info: :class:`~google.api_core.client_info.ClientInfo` :param client_info: (Optional) instance used to generate user agent. + + :type api_endpoint: str + :param api_endpoint: (Optional) api endpoint to use. """ DEFAULT_API_ENDPOINT = "https://storage.googleapis.com" + DEFAULT_API_MTLS_ENDPOINT = "https://storage.mtls.googleapis.com" - def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT): + def __init__(self, client, client_info=None, api_endpoint=None): super(Connection, self).__init__(client, client_info) - self.API_BASE_URL = api_endpoint + self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT + self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT + self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None self._client_info.client_library_version = __version__ # TODO: When metrics all use gccl, this should be removed #9552 diff --git a/google/cloud/storage/blob.py b/google/cloud/storage/blob.py index f1cb5666b..119e3318f 100644 --- a/google/cloud/storage/blob.py +++ b/google/cloud/storage/blob.py @@ -830,9 +830,8 @@ def _get_download_url( """ name_value_pairs = [] if self.media_link is None: - base_url = _DOWNLOAD_URL_TEMPLATE.format( - hostname=client._connection.API_BASE_URL, path=self.path - ) + hostname = _get_host_name(client._connection) + base_url = _DOWNLOAD_URL_TEMPLATE.format(hostname=hostname, path=self.path) if self.generation is not None: name_value_pairs.append(("generation", "{:d}".format(self.generation))) else: @@ -1685,8 +1684,9 @@ def _do_multipart_upload( info = self._get_upload_arguments(content_type) headers, object_metadata, content_type = info + hostname = _get_host_name(client._connection) base_url = _MULTIPART_URL_TEMPLATE.format( - hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path + hostname=hostname, bucket_path=self.bucket.path ) name_value_pairs = [] @@ -1866,8 +1866,9 @@ def _initiate_resumable_upload( if extra_headers is not None: headers.update(extra_headers) + hostname = _get_host_name(client._connection) base_url = _RESUMABLE_URL_TEMPLATE.format( - hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path + hostname=hostname, bucket_path=self.bucket.path ) name_value_pairs = [] @@ -3798,6 +3799,25 @@ def custom_time(self, value): self._patch_property("customTime", value) +def _get_host_name(connection): + """Returns the host name from the given connection. + + :type connection: :class:`~google.cloud.storage._http.Connection` + :param connection: The connection object. + + :rtype: str + :returns: The host name. + """ + # TODO: After google-cloud-core 1.6.0 is stable and we upgrade it + # to 1.6.0 in setup.py, we no longer need to check the attribute + # existence. We can simply return connection.get_api_base_url_for_mtls(). + return ( + connection.API_BASE_URL + if not hasattr(connection, "get_api_base_url_for_mtls") + else connection.get_api_base_url_for_mtls() + ) + + def _get_encryption_headers(key, source=False): """Builds customer encryption key headers diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index 8812dc32e..36ee6b9f2 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -32,6 +32,7 @@ from google.cloud.client import ClientWithProject from google.cloud.exceptions import NotFound from google.cloud.storage._helpers import _get_storage_host +from google.cloud.storage._helpers import _DEFAULT_STORAGE_HOST from google.cloud.storage._helpers import _bucket_bound_hostname_url from google.cloud.storage._http import Connection from google.cloud.storage._signing import ( @@ -127,7 +128,14 @@ def __init__( kw_args = {"client_info": client_info} - kw_args["api_endpoint"] = _get_storage_host() + # `api_endpoint` should be only set by the user via `client_options`, + # or if the _get_storage_host() returns a non-default value. + # `api_endpoint` plays an important role for mTLS, if it is not set, + # then mTLS logic will be applied to decide which endpoint will be used. + storage_host = _get_storage_host() + kw_args["api_endpoint"] = ( + storage_host if storage_host != _DEFAULT_STORAGE_HOST else None + ) if client_options: if type(client_options) == dict: diff --git a/noxfile.py b/noxfile.py index 0609a80f6..d9efc06ec 100644 --- a/noxfile.py +++ b/noxfile.py @@ -107,6 +107,9 @@ def system(session): # Sanity check: Only run tests if the environment variable is set. if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""): session.skip("Credentials must be set via environment variable") + # mTLS tests requires pyopenssl. + if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "") == "true": + session.install("pyopenssl") system_test_exists = os.path.exists(system_test_path) system_test_folder_exists = os.path.exists(system_test_folder_path) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index ec05b0c72..bb6ff5b54 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -81,6 +81,7 @@ class Config(object): CLIENT = None TEST_BUCKET = None + TESTING_MTLS = False def setUpModule(): @@ -91,6 +92,10 @@ def setUpModule(): Config.TEST_BUCKET = Config.CLIENT.bucket(bucket_name) Config.TEST_BUCKET.versioning_enabled = True retry_429_503(Config.TEST_BUCKET.create)() + # mTLS testing uses the system test as well. For mTLS testing, + # GOOGLE_API_USE_CLIENT_CERTIFICATE env var will be set to "true" + # explicitly. + Config.TESTING_MTLS = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" def tearDownModule(): @@ -101,6 +106,15 @@ def tearDownModule(): class TestClient(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(TestClient, cls).setUpClass() + if ( + type(Config.CLIENT._credentials) + is not google.oauth2.service_account.Credentials + ): + raise unittest.SkipTest("These tests require a service account credential") + def setUp(self): self.case_hmac_keys_to_delete = [] @@ -563,6 +577,15 @@ def tearDown(self): class TestStorageWriteFiles(TestStorageFiles): ENCRYPTION_KEY = "b23ff11bba187db8c37077e6af3b25b8" + @classmethod + def setUpClass(cls): + super(TestStorageWriteFiles, cls).setUpClass() + if ( + type(Config.CLIENT._credentials) + is not google.oauth2.service_account.Credentials + ): + raise unittest.SkipTest("These tests require a service account credential") + def test_large_file_write_from_stream(self): blob = self.bucket.blob("LargeFile") @@ -1285,11 +1308,14 @@ class TestStorageSignURLs(unittest.TestCase): @classmethod def setUpClass(cls): + super(TestStorageSignURLs, cls).setUpClass() if ( type(Config.CLIENT._credentials) is not google.oauth2.service_account.Credentials ): - cls.skipTest("Signing tests requires a service account credential") + raise unittest.SkipTest( + "Signing tests requires a service account credential" + ) bucket_name = "gcp-signing" + unique_resource_id() cls.bucket = retry_429_503(Config.CLIENT.create_bucket)(bucket_name) @@ -1850,6 +1876,18 @@ class TestStorageNotificationCRUD(unittest.TestCase): CUSTOM_ATTRIBUTES = {"attr1": "value1", "attr2": "value2"} BLOB_NAME_PREFIX = "blob-name-prefix/" + @classmethod + def setUpClass(cls): + super(TestStorageNotificationCRUD, cls).setUpClass() + if Config.TESTING_MTLS: + # mTLS is only available for python-pubsub >= 2.2.0. However, the + # system test uses python-pubsub < 2.0, so we skip those tests. + # Note that python-pubsub >= 2.0 no longer supports python 2.7, so + # we can only upgrade it after python 2.7 system test is removed. + # Since python-pubsub >= 2.0 has a new set of api, the test code + # also needs to be updated. + raise unittest.SkipTest("Skip pubsub tests for mTLS testing") + @property def topic_path(self): return "projects/{}/topics/{}".format(Config.CLIENT.project, self.TOPIC_NAME) @@ -2013,6 +2051,15 @@ def _kms_key_name(self, key_name=None): @classmethod def setUpClass(cls): super(TestKMSIntegration, cls).setUpClass() + if Config.TESTING_MTLS: + # mTLS is only available for python-kms >= 2.2.0. However, the + # system test uses python-kms < 2.0, so we skip those tests. + # Note that python-kms >= 2.0 no longer supports python 2.7, so + # we can only upgrade it after python 2.7 system test is removed. + # Since python-kms >= 2.0 has a new set of api, the test code + # also needs to be updated. + raise unittest.SkipTest("Skip kms tests for mTLS testing") + _empty_bucket(Config.CLIENT, cls.bucket) def setUp(self): @@ -2466,6 +2513,17 @@ def test_ubla_set_unset_preserves_acls(self): class TestV4POSTPolicies(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(TestV4POSTPolicies, cls).setUpClass() + if ( + type(Config.CLIENT._credentials) + is not google.oauth2.service_account.Credentials + ): + # mTLS only works for user credentials, it doesn't work for + # service account credentials. + raise unittest.SkipTest("These tests require a service account credential") + def setUp(self): self.case_buckets_to_delete = [] diff --git a/tests/unit/test__http.py b/tests/unit/test__http.py index 00cb4d34e..ac8471d07 100644 --- a/tests/unit/test__http.py +++ b/tests/unit/test__http.py @@ -25,6 +25,8 @@ def _get_target_class(): return Connection def _make_one(self, *args, **kw): + if "api_endpoint" not in kw: + kw["api_endpoint"] = "https://storage.googleapis.com" return self._get_target_class()(*args, **kw) def test_extra_headers(self): @@ -213,3 +215,16 @@ def test_api_request_conditional_retry_failed(self): retry=conditional_retry_mock, ) http.request.assert_called_once() + + def test_mtls(self): + client = object() + + conn = self._make_one(client, api_endpoint=None) + self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True) + self.assertEqual(conn.API_BASE_URL, "https://storage.googleapis.com") + self.assertEqual(conn.API_BASE_MTLS_URL, "https://storage.mtls.googleapis.com") + + conn = self._make_one(client, api_endpoint="http://foo") + self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False) + self.assertEqual(conn.API_BASE_URL, "http://foo") + self.assertEqual(conn.API_BASE_MTLS_URL, "https://storage.mtls.googleapis.com") diff --git a/tests/unit/test_blob.py b/tests/unit/test_blob.py index cd6ecafa0..4aacc3a8c 100644 --- a/tests/unit/test_blob.py +++ b/tests/unit/test_blob.py @@ -909,6 +909,24 @@ def test__get_download_url_on_the_fly(self): ) self.assertEqual(download_url, expected_url) + def test__get_download_url_mtls(self): + blob_name = "bzzz-fly.txt" + bucket = _Bucket(name="buhkit") + blob = self._make_one(blob_name, bucket=bucket) + + self.assertIsNone(blob.media_link) + client = mock.Mock(_connection=_Connection) + client._connection.API_BASE_URL = "https://storage.googleapis.com" + client._connection.get_api_base_url_for_mtls = mock.Mock( + return_value="https://foo.mtls" + ) + download_url = blob._get_download_url(client) + del client._connection.get_api_base_url_for_mtls + expected_url = ( + "https://foo.mtls/download/storage/v1/b/" "buhkit/o/bzzz-fly.txt?alt=media" + ) + self.assertEqual(download_url, expected_url) + def test__get_download_url_on_the_fly_with_generation(self): blob_name = "pretend.txt" bucket = _Bucket(name="fictional") @@ -1959,6 +1977,7 @@ def _do_multipart_success( kms_key_name=None, timeout=None, metadata=None, + mtls=False, ): from six.moves.urllib.parse import urlencode @@ -1977,6 +1996,14 @@ def _do_multipart_success( client = mock.Mock(_http=transport, _connection=_Connection, spec=["_http"]) client._connection.API_BASE_URL = "https://storage.googleapis.com" + + # Mock get_api_base_url_for_mtls function. + mtls_url = "https://foo.mtls" + if mtls: + client._connection.get_api_base_url_for_mtls = mock.Mock( + return_value=mtls_url + ) + data = b"data here hear hier" stream = io.BytesIO(data) content_type = u"application/xml" @@ -2002,6 +2029,10 @@ def _do_multipart_success( **timeout_kwarg ) + # Clean up the get_api_base_url_for_mtls mock. + if mtls: + del client._connection.get_api_base_url_for_mtls + # Check the mocks and the returned value. self.assertIs(response, client._http.request.return_value) if size is None: @@ -2016,6 +2047,8 @@ def _do_multipart_success( upload_url = ( "https://storage.googleapis.com/upload/storage/v1" + bucket.path + "/o" ) + if mtls: + upload_url = mtls_url + "/upload/storage/v1" + bucket.path + "/o" qs_params = [("uploadType", "multipart")] @@ -2064,6 +2097,12 @@ def _do_multipart_success( def test__do_multipart_upload_no_size(self, mock_get_boundary): self._do_multipart_success(mock_get_boundary, predefined_acl="private") + @mock.patch(u"google.resumable_media._upload.get_boundary", return_value=b"==0==") + def test__do_multipart_upload_no_size_mtls(self, mock_get_boundary): + self._do_multipart_success( + mock_get_boundary, predefined_acl="private", mtls=True + ) + @mock.patch(u"google.resumable_media._upload.get_boundary", return_value=b"==0==") def test__do_multipart_upload_with_size(self, mock_get_boundary): self._do_multipart_success(mock_get_boundary, size=10) @@ -2159,6 +2198,7 @@ def _initiate_resumable_helper( kms_key_name=None, timeout=None, metadata=None, + mtls=False, ): from six.moves.urllib.parse import urlencode from google.resumable_media.requests import ResumableUpload @@ -2197,6 +2237,14 @@ def _initiate_resumable_helper( _http=transport, _connection=_Connection, spec=[u"_http"] ) client._connection.API_BASE_URL = "https://storage.googleapis.com" + + # Mock get_api_base_url_for_mtls function. + mtls_url = "https://foo.mtls" + if mtls: + client._connection.get_api_base_url_for_mtls = mock.Mock( + return_value=mtls_url + ) + data = b"hello hallo halo hi-low" stream = io.BytesIO(data) content_type = u"text/plain" @@ -2224,12 +2272,18 @@ def _initiate_resumable_helper( **timeout_kwarg ) + # Clean up the get_api_base_url_for_mtls mock. + if mtls: + del client._connection.get_api_base_url_for_mtls + # Check the returned values. self.assertIsInstance(upload, ResumableUpload) upload_url = ( "https://storage.googleapis.com/upload/storage/v1" + bucket.path + "/o" ) + if mtls: + upload_url = mtls_url + "/upload/storage/v1" + bucket.path + "/o" qs_params = [("uploadType", "resumable")] if user_project is not None: @@ -2322,6 +2376,9 @@ def test__initiate_resumable_upload_with_custom_timeout(self): def test__initiate_resumable_upload_no_size(self): self._initiate_resumable_helper() + def test__initiate_resumable_upload_no_size_mtls(self): + self._initiate_resumable_helper(mtls=True) + def test__initiate_resumable_upload_with_size(self): self._initiate_resumable_helper(size=10000) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index a4c23d7cc..93cfdb8ca 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -89,6 +89,7 @@ def _make_json_response(data, status=http_client.OK, headers=None): def _make_requests_session(responses): session = mock.create_autospec(requests.Session, instance=True) session.request.side_effect = responses + session.is_mtls = False return session @@ -219,6 +220,21 @@ def test_ctor_w_client_info(self): self.assertEqual(list(client._batch_stack), []) self.assertIs(client._connection._client_info, client_info) + def test_ctor_mtls(self): + credentials = _make_credentials() + + client = self._make_one(credentials=credentials) + self.assertEqual(client._connection.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True) + self.assertEqual( + client._connection.API_BASE_URL, "https://storage.googleapis.com" + ) + + client = self._make_one( + credentials=credentials, client_options={"api_endpoint": "http://foo"} + ) + self.assertEqual(client._connection.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False) + self.assertEqual(client._connection.API_BASE_URL, "http://foo") + def test_create_anonymous_client(self): from google.auth.credentials import AnonymousCredentials from google.cloud.storage._http import Connection