From 1823cadee3acf95c516d0479400e4175349ea199 Mon Sep 17 00:00:00 2001 From: arithmetic1728 <58957152+arithmetic1728@users.noreply.github.com> Date: Fri, 5 Feb 2021 07:53:15 -0800 Subject: [PATCH] feat: add mtls support to client (#492) * feat: add mtls feature --- google/cloud/bigquery/_http.py | 21 +++++++++++++++++++-- google/cloud/bigquery/client.py | 25 +++++++++++++++++++------ tests/system/test_client.py | 6 ++++++ tests/unit/helpers.py | 2 ++ tests/unit/test__http.py | 14 ++++++++++++++ tests/unit/test_client.py | 23 +++++++++++++++++++---- 6 files changed, 79 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigquery/_http.py b/google/cloud/bigquery/_http.py index 8ee633e64..ede26cc70 100644 --- a/google/cloud/bigquery/_http.py +++ b/google/cloud/bigquery/_http.py @@ -14,11 +14,23 @@ """Create / interact with Google BigQuery connections.""" +import os +import pkg_resources + from google.cloud import _http from google.cloud.bigquery import __version__ +# TODO: Increase the minimum version of google-cloud-core to 1.6.0 +# and remove this logic. See: +# https://github.com/googleapis/python-bigquery/issues/509 +if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER + release = pkg_resources.get_distribution("google-cloud-core").parsed_version + if release < pkg_resources.parse_version("1.6.0"): + raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature") + + class Connection(_http.JSONConnection): """A connection to Google BigQuery via the JSON REST API. @@ -26,13 +38,18 @@ class Connection(_http.JSONConnection): client (google.cloud.bigquery.client.Client): The client that owns the current connection. client_info (Optional[google.api_core.client_info.ClientInfo]): Instance used to generate user agent. + + api_endpoint (str): The api_endpoint to use. If None, the library will decide what endpoint to use. """ DEFAULT_API_ENDPOINT = "https://bigquery.googleapis.com" + DEFAULT_API_MTLS_ENDPOINT = "https://bigquery.mtls.googleapis.com" - def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT): + def __init__(self, client, client_info=None, api_endpoint=None): super(Connection, self).__init__(client, client_info) - self.API_BASE_URL = api_endpoint + self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT + self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT + self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None self._client_info.gapic_version = __version__ self._client_info.client_library_version = __version__ diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index b270075a9..f8c0d7c93 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -78,10 +78,7 @@ _DEFAULT_CHUNKSIZE = 1048576 # 1024 * 1024 B = 1 MB _MAX_MULTIPART_SIZE = 5 * 1024 * 1024 _DEFAULT_NUM_RETRIES = 6 -_BASE_UPLOAD_TEMPLATE = ( - "https://bigquery.googleapis.com/upload/bigquery/v2/projects/" - "{project}/jobs?uploadType=" -) +_BASE_UPLOAD_TEMPLATE = "{host}/upload/bigquery/v2/projects/{project}/jobs?uploadType=" _MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "multipart" _RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "resumable" _GENERIC_CONTENT_TYPE = "*/*" @@ -2547,7 +2544,15 @@ def _initiate_resumable_upload( if project is None: project = self.project - upload_url = _RESUMABLE_URL_TEMPLATE.format(project=project) + # TODO: Increase the minimum version of google-cloud-core to 1.6.0 + # and remove this logic. See: + # https://github.com/googleapis/python-bigquery/issues/509 + hostname = ( + self._connection.API_BASE_URL + if not hasattr(self._connection, "get_api_base_url_for_mtls") + else self._connection.get_api_base_url_for_mtls() + ) + upload_url = _RESUMABLE_URL_TEMPLATE.format(host=hostname, project=project) # TODO: modify ResumableUpload to take a retry.Retry object # that it can use for the initial RPC. @@ -2616,7 +2621,15 @@ def _do_multipart_upload( if project is None: project = self.project - upload_url = _MULTIPART_URL_TEMPLATE.format(project=project) + # TODO: Increase the minimum version of google-cloud-core to 1.6.0 + # and remove this logic. See: + # https://github.com/googleapis/python-bigquery/issues/509 + hostname = ( + self._connection.API_BASE_URL + if not hasattr(self._connection, "get_api_base_url_for_mtls") + else self._connection.get_api_base_url_for_mtls() + ) + upload_url = _MULTIPART_URL_TEMPLATE.format(host=hostname, project=project) upload = MultipartUpload(upload_url, headers=headers) if num_retries is not None: diff --git a/tests/system/test_client.py b/tests/system/test_client.py index aa1a03160..85c044bad 100644 --- a/tests/system/test_client.py +++ b/tests/system/test_client.py @@ -28,6 +28,7 @@ import uuid import psutil +import pytest import pytz import pkg_resources @@ -132,6 +133,8 @@ else: PYARROW_INSTALLED_VERSION = None +MTLS_TESTING = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" + def _has_rows(result): return len(result) > 0 @@ -2651,6 +2654,9 @@ def test_insert_rows_nested_nested_dictionary(self): expected_rows = [("Some value", record)] self.assertEqual(row_tuples, expected_rows) + @pytest.mark.skipif( + MTLS_TESTING, reason="mTLS testing has no permission to the max-value.js file" + ) def test_create_routine(self): routine_name = "test_routine" dataset = self.temp_dataset(_make_dataset_id("create_routine")) diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index eea345e89..b51b0bbb7 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -21,6 +21,8 @@ def make_connection(*responses): mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection) mock_conn.user_agent = "testing 1.2.3" mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")] + mock_conn.API_BASE_URL = "https://bigquery.googleapis.com" + mock_conn.get_api_base_url_for_mtls = mock.Mock(return_value=mock_conn.API_BASE_URL) return mock_conn diff --git a/tests/unit/test__http.py b/tests/unit/test__http.py index 78e59cb30..09f6d29d7 100644 --- a/tests/unit/test__http.py +++ b/tests/unit/test__http.py @@ -32,6 +32,9 @@ def _get_target_class(): return Connection def _make_one(self, *args, **kw): + if "api_endpoint" not in kw: + kw["api_endpoint"] = "https://bigquery.googleapis.com" + return self._get_target_class()(*args, **kw) def test_build_api_url_no_extra_query_params(self): @@ -138,3 +141,14 @@ def test_extra_headers_replace(self): url=expected_uri, timeout=self._get_default_timeout(), ) + + def test_ctor_mtls(self): + conn = self._make_one(object(), api_endpoint=None) + self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True) + self.assertEqual(conn.API_BASE_URL, "https://bigquery.googleapis.com") + self.assertEqual(conn.API_BASE_MTLS_URL, "https://bigquery.mtls.googleapis.com") + + conn = self._make_one(object(), api_endpoint="http://foo") + self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False) + self.assertEqual(conn.API_BASE_URL, "http://foo") + self.assertEqual(conn.API_BASE_MTLS_URL, "https://bigquery.mtls.googleapis.com") diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 625256e6e..66add9c0a 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -2057,6 +2057,7 @@ def test_get_table_sets_user_agent(self): url=mock.ANY, method=mock.ANY, headers=mock.ANY, data=mock.ANY ) http.reset_mock() + http.is_mtls = False mock_response.status_code = 200 mock_response.json.return_value = self._make_table_resource() user_agent_override = client_info.ClientInfo(user_agent="my-application/1.2.3") @@ -4425,7 +4426,7 @@ def _mock_transport(self, status_code, headers, content=b""): fake_transport.request.return_value = fake_response return fake_transport - def _initiate_resumable_upload_helper(self, num_retries=None): + def _initiate_resumable_upload_helper(self, num_retries=None, mtls=False): from google.resumable_media.requests import ResumableUpload from google.cloud.bigquery.client import _DEFAULT_CHUNKSIZE from google.cloud.bigquery.client import _GENERIC_CONTENT_TYPE @@ -4440,6 +4441,8 @@ def _initiate_resumable_upload_helper(self, num_retries=None): fake_transport = self._mock_transport(http.client.OK, response_headers) client = self._make_one(project=self.PROJECT, _http=fake_transport) conn = client._connection = make_connection() + if mtls: + conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls") # Create some mock arguments and call the method under test. data = b"goodbye gudbi gootbee" @@ -4454,8 +4457,10 @@ def _initiate_resumable_upload_helper(self, num_retries=None): # Check the returned values. self.assertIsInstance(upload, ResumableUpload) + + host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com" upload_url = ( - f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{self.PROJECT}" + f"{host_name}/upload/bigquery/v2/projects/{self.PROJECT}" "/jobs?uploadType=resumable" ) self.assertEqual(upload.upload_url, upload_url) @@ -4494,11 +4499,14 @@ def _initiate_resumable_upload_helper(self, num_retries=None): def test__initiate_resumable_upload(self): self._initiate_resumable_upload_helper() + def test__initiate_resumable_upload_mtls(self): + self._initiate_resumable_upload_helper(mtls=True) + def test__initiate_resumable_upload_with_retry(self): self._initiate_resumable_upload_helper(num_retries=11) def _do_multipart_upload_success_helper( - self, get_boundary, num_retries=None, project=None + self, get_boundary, num_retries=None, project=None, mtls=False ): from google.cloud.bigquery.client import _get_upload_headers from google.cloud.bigquery.job import LoadJob @@ -4508,6 +4516,8 @@ def _do_multipart_upload_success_helper( fake_transport = self._mock_transport(http.client.OK, {}) client = self._make_one(project=self.PROJECT, _http=fake_transport) conn = client._connection = make_connection() + if mtls: + conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls") if project is None: project = self.PROJECT @@ -4530,8 +4540,9 @@ def _do_multipart_upload_success_helper( self.assertEqual(stream.tell(), size) get_boundary.assert_called_once_with() + host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com" upload_url = ( - f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{project}" + f"{host_name}/upload/bigquery/v2/projects/{project}" "/jobs?uploadType=multipart" ) payload = ( @@ -4556,6 +4567,10 @@ def _do_multipart_upload_success_helper( def test__do_multipart_upload(self, get_boundary): self._do_multipart_upload_success_helper(get_boundary) + @mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==") + def test__do_multipart_upload_mtls(self, get_boundary): + self._do_multipart_upload_success_helper(get_boundary, mtls=True) + @mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==") def test__do_multipart_upload_with_retry(self, get_boundary): self._do_multipart_upload_success_helper(get_boundary, num_retries=8)