diff --git a/google/cloud/storage/_http.py b/google/cloud/storage/_http.py index 6e175196c..e41b63b6b 100644 --- a/google/cloud/storage/_http.py +++ b/google/cloud/storage/_http.py @@ -29,13 +29,19 @@ class Connection(_http.JSONConnection): :type client_info: :class:`~google.api_core.client_info.ClientInfo` :param client_info: (Optional) instance used to generate user agent. + + :type api_endpoint: str or None + :param api_endpoint: (Optional) api endpoint to use. """ DEFAULT_API_ENDPOINT = "https://storage.googleapis.com" + DEFAULT_API_MTLS_ENDPOINT = "https://storage.mtls.googleapis.com" - def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT): + def __init__(self, client, client_info=None, api_endpoint=None): super(Connection, self).__init__(client, client_info) - self.API_BASE_URL = api_endpoint + self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT + self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT + self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None self._client_info.client_library_version = __version__ # TODO: When metrics all use gccl, this should be removed #9552 diff --git a/google/cloud/storage/blob.py b/google/cloud/storage/blob.py index 8564f8e0d..e0a6bf864 100644 --- a/google/cloud/storage/blob.py +++ b/google/cloud/storage/blob.py @@ -829,9 +829,12 @@ def _get_download_url( """ name_value_pairs = [] if self.media_link is None: - base_url = _DOWNLOAD_URL_TEMPLATE.format( - hostname=client._connection.API_BASE_URL, path=self.path + hostname = ( + client._connection.API_BASE_URL + if not hasattr(client._connection, "get_api_base_url_for_mtls") + else client._connection.get_api_base_url_for_mtls() ) + base_url = _DOWNLOAD_URL_TEMPLATE.format(hostname=hostname, path=self.path) if self.generation is not None: name_value_pairs.append(("generation", "{:d}".format(self.generation))) else: @@ -1683,8 +1686,13 @@ def _do_multipart_upload( info = self._get_upload_arguments(content_type) headers, object_metadata, content_type = info + hostname = ( + client._connection.API_BASE_URL + if not hasattr(client._connection, "get_api_base_url_for_mtls") + else client._connection.get_api_base_url_for_mtls() + ) base_url = _MULTIPART_URL_TEMPLATE.format( - hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path + hostname=hostname, bucket_path=self.bucket.path ) name_value_pairs = [] @@ -1864,8 +1872,13 @@ def _initiate_resumable_upload( if extra_headers is not None: headers.update(extra_headers) + hostname = ( + client._connection.API_BASE_URL + if not hasattr(client._connection, "get_api_base_url_for_mtls") + else client._connection.get_api_base_url_for_mtls() + ) base_url = _RESUMABLE_URL_TEMPLATE.format( - hostname=client._connection.API_BASE_URL, bucket_path=self.bucket.path + hostname=hostname, bucket_path=self.bucket.path ) name_value_pairs = [] diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index 8812dc32e..e6a9fe276 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -32,6 +32,7 @@ from google.cloud.client import ClientWithProject from google.cloud.exceptions import NotFound from google.cloud.storage._helpers import _get_storage_host +from google.cloud.storage._helpers import _DEFAULT_STORAGE_HOST from google.cloud.storage._helpers import _bucket_bound_hostname_url from google.cloud.storage._http import Connection from google.cloud.storage._signing import ( @@ -127,7 +128,13 @@ def __init__( kw_args = {"client_info": client_info} - kw_args["api_endpoint"] = _get_storage_host() + # `api_endpoint` should be only set by the user via `client_options`, + # or if the _get_storage_host() returns a non-default value. + # `api_endpoint` plays an important role for mTLS, if it is not set, + # then mTLS logic will be applied to decide which endpoint will be used. + storage_host = _get_storage_host() + if storage_host != _DEFAULT_STORAGE_HOST: + kw_args["api_endpoint"] = storage_host if client_options: if type(client_options) == dict: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 1ff17a61f..3d8358356 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -81,6 +81,7 @@ class Config(object): CLIENT = None TEST_BUCKET = None + TESTING_MTLS = False def setUpModule(): @@ -91,6 +92,10 @@ def setUpModule(): Config.TEST_BUCKET = Config.CLIENT.bucket(bucket_name) Config.TEST_BUCKET.versioning_enabled = True retry_429_503(Config.TEST_BUCKET.create)() + # mTLS testing uses the system test as well. For mTLS testing, + # GOOGLE_API_USE_CLIENT_CERTIFICATE env var will be set to "true" + # explicitly. + Config.TESTING_MTLS = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" def tearDownModule(): @@ -101,6 +106,15 @@ def tearDownModule(): class TestClient(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(TestClient, cls).setUpClass() + if ( + type(Config.CLIENT._credentials) + is not google.oauth2.service_account.Credentials + ): + cls.skipTest(cls, reason="These tests require a service account credential") + def setUp(self): self.case_hmac_keys_to_delete = [] @@ -563,6 +577,15 @@ def tearDown(self): class TestStorageWriteFiles(TestStorageFiles): ENCRYPTION_KEY = "b23ff11bba187db8c37077e6af3b25b8" + @classmethod + def setUpClass(cls): + super(TestStorageWriteFiles, cls).setUpClass() + if ( + type(Config.CLIENT._credentials) + is not google.oauth2.service_account.Credentials + ): + cls.skipTest(cls, reason="These tests require a service account credential") + def test_large_file_write_from_stream(self): blob = self.bucket.blob("LargeFile") @@ -1272,11 +1295,14 @@ class TestStorageSignURLs(unittest.TestCase): @classmethod def setUpClass(cls): + super(TestStorageSignURLs, cls).setUpClass() if ( type(Config.CLIENT._credentials) is not google.oauth2.service_account.Credentials ): - cls.skipTest("Signing tests requires a service account credential") + cls.skipTest( + cls, reason="Signing tests requires a service account credential" + ) bucket_name = "gcp-signing" + unique_resource_id() cls.bucket = retry_429_503(Config.CLIENT.create_bucket)(bucket_name) @@ -1837,6 +1863,12 @@ class TestStorageNotificationCRUD(unittest.TestCase): CUSTOM_ATTRIBUTES = {"attr1": "value1", "attr2": "value2"} BLOB_NAME_PREFIX = "blob-name-prefix/" + @classmethod + def setUpClass(cls): + super(TestStorageNotificationCRUD, cls).setUpClass() + if Config.TESTING_MTLS: + cls.skipTest(cls, reason="Skip pubsub tests for mTLS testing") + @property def topic_path(self): return "projects/{}/topics/{}".format(Config.CLIENT.project, self.TOPIC_NAME) @@ -2000,6 +2032,9 @@ def _kms_key_name(self, key_name=None): @classmethod def setUpClass(cls): super(TestKMSIntegration, cls).setUpClass() + if Config.TESTING_MTLS: + cls.skipTest(cls, reason="Skip kms tests for mTLS testing") + _empty_bucket(Config.CLIENT, cls.bucket) def setUp(self): @@ -2453,6 +2488,15 @@ def test_ubla_set_unset_preserves_acls(self): class TestV4POSTPolicies(unittest.TestCase): + @classmethod + def setUpClass(cls): + super(TestV4POSTPolicies, cls).setUpClass() + if ( + type(Config.CLIENT._credentials) + is not google.oauth2.service_account.Credentials + ): + cls.skipTest(cls, reason="These tests require a service account credential") + def setUp(self): self.case_buckets_to_delete = []