Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mtls feature #492

Merged
merged 6 commits into from Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 19 additions & 2 deletions google/cloud/bigquery/_http.py
Expand Up @@ -14,25 +14,42 @@

"""Create / interact with Google BigQuery connections."""

import os
import pkg_resources

from google.cloud import _http

from google.cloud.bigquery import __version__


# TODO: Increase the minimum version of google-cloud-core to 1.6.0
# and remove this logic. See:
# https://github.com/googleapis/python-bigquery/issues/509
if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true": # pragma: NO COVER
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
release = pkg_resources.get_distribution("google-cloud-core").parsed_version
if release < pkg_resources.parse_version("1.6.0"):
raise ImportError("google-cloud-core >= 1.6.0 is required to use mTLS feature")


class Connection(_http.JSONConnection):
"""A connection to Google BigQuery via the JSON REST API.

Args:
client (google.cloud.bigquery.client.Client): The client that owns the current connection.

client_info (Optional[google.api_core.client_info.ClientInfo]): Instance used to generate user agent.

api_endpoint (str): The api_endpoint to use. If None, the library will decide what endpoint to use.
"""

DEFAULT_API_ENDPOINT = "https://bigquery.googleapis.com"
DEFAULT_API_MTLS_ENDPOINT = "https://bigquery.mtls.googleapis.com"

def __init__(self, client, client_info=None, api_endpoint=DEFAULT_API_ENDPOINT):
def __init__(self, client, client_info=None, api_endpoint=None):
super(Connection, self).__init__(client, client_info)
self.API_BASE_URL = api_endpoint
self.API_BASE_URL = api_endpoint or self.DEFAULT_API_ENDPOINT
self.API_BASE_MTLS_URL = self.DEFAULT_API_MTLS_ENDPOINT
self.ALLOW_AUTO_SWITCH_TO_MTLS_URL = api_endpoint is None
self._client_info.gapic_version = __version__
self._client_info.client_library_version = __version__

Expand Down
25 changes: 19 additions & 6 deletions google/cloud/bigquery/client.py
Expand Up @@ -78,10 +78,7 @@
_DEFAULT_CHUNKSIZE = 1048576 # 1024 * 1024 B = 1 MB
_MAX_MULTIPART_SIZE = 5 * 1024 * 1024
_DEFAULT_NUM_RETRIES = 6
_BASE_UPLOAD_TEMPLATE = (
"https://bigquery.googleapis.com/upload/bigquery/v2/projects/"
"{project}/jobs?uploadType="
)
_BASE_UPLOAD_TEMPLATE = "{host}/upload/bigquery/v2/projects/{project}/jobs?uploadType="
_MULTIPART_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "multipart"
_RESUMABLE_URL_TEMPLATE = _BASE_UPLOAD_TEMPLATE + "resumable"
_GENERIC_CONTENT_TYPE = "*/*"
Expand Down Expand Up @@ -2547,7 +2544,15 @@ def _initiate_resumable_upload(

if project is None:
project = self.project
upload_url = _RESUMABLE_URL_TEMPLATE.format(project=project)
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
# and remove this logic. See:
# https://github.com/googleapis/python-bigquery/issues/509
hostname = (
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
self._connection.API_BASE_URL
if not hasattr(self._connection, "get_api_base_url_for_mtls")
else self._connection.get_api_base_url_for_mtls()
)
upload_url = _RESUMABLE_URL_TEMPLATE.format(host=hostname, project=project)

# TODO: modify ResumableUpload to take a retry.Retry object
# that it can use for the initial RPC.
Expand Down Expand Up @@ -2616,7 +2621,15 @@ def _do_multipart_upload(
if project is None:
project = self.project

upload_url = _MULTIPART_URL_TEMPLATE.format(project=project)
# TODO: Increase the minimum version of google-cloud-core to 1.6.0
# and remove this logic. See:
# https://github.com/googleapis/python-bigquery/issues/509
hostname = (
arithmetic1728 marked this conversation as resolved.
Show resolved Hide resolved
self._connection.API_BASE_URL
if not hasattr(self._connection, "get_api_base_url_for_mtls")
else self._connection.get_api_base_url_for_mtls()
)
upload_url = _MULTIPART_URL_TEMPLATE.format(host=hostname, project=project)
upload = MultipartUpload(upload_url, headers=headers)

if num_retries is not None:
Expand Down
6 changes: 6 additions & 0 deletions tests/system/test_client.py
Expand Up @@ -28,6 +28,7 @@
import uuid

import psutil
import pytest
import pytz
import pkg_resources

Expand Down Expand Up @@ -132,6 +133,8 @@
else:
PYARROW_INSTALLED_VERSION = None

MTLS_TESTING = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true"


def _has_rows(result):
return len(result) > 0
Expand Down Expand Up @@ -2651,6 +2654,9 @@ def test_insert_rows_nested_nested_dictionary(self):
expected_rows = [("Some value", record)]
self.assertEqual(row_tuples, expected_rows)

@pytest.mark.skipif(
MTLS_TESTING, reason="mTLS testing has no permission to the max-value.js file"
)
def test_create_routine(self):
routine_name = "test_routine"
dataset = self.temp_dataset(_make_dataset_id("create_routine"))
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/helpers.py
Expand Up @@ -21,6 +21,8 @@ def make_connection(*responses):
mock_conn = mock.create_autospec(google.cloud.bigquery._http.Connection)
mock_conn.user_agent = "testing 1.2.3"
mock_conn.api_request.side_effect = list(responses) + [NotFound("miss")]
mock_conn.API_BASE_URL = "https://bigquery.googleapis.com"
mock_conn.get_api_base_url_for_mtls = mock.Mock(return_value=mock_conn.API_BASE_URL)
return mock_conn


Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test__http.py
Expand Up @@ -32,6 +32,9 @@ def _get_target_class():
return Connection

def _make_one(self, *args, **kw):
if "api_endpoint" not in kw:
kw["api_endpoint"] = "https://bigquery.googleapis.com"

return self._get_target_class()(*args, **kw)

def test_build_api_url_no_extra_query_params(self):
Expand Down Expand Up @@ -138,3 +141,14 @@ def test_extra_headers_replace(self):
url=expected_uri,
timeout=self._get_default_timeout(),
)

def test_ctor_mtls(self):
conn = self._make_one(object(), api_endpoint=None)
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, True)
self.assertEqual(conn.API_BASE_URL, "https://bigquery.googleapis.com")
self.assertEqual(conn.API_BASE_MTLS_URL, "https://bigquery.mtls.googleapis.com")

conn = self._make_one(object(), api_endpoint="http://foo")
self.assertEqual(conn.ALLOW_AUTO_SWITCH_TO_MTLS_URL, False)
self.assertEqual(conn.API_BASE_URL, "http://foo")
self.assertEqual(conn.API_BASE_MTLS_URL, "https://bigquery.mtls.googleapis.com")
23 changes: 19 additions & 4 deletions tests/unit/test_client.py
Expand Up @@ -2057,6 +2057,7 @@ def test_get_table_sets_user_agent(self):
url=mock.ANY, method=mock.ANY, headers=mock.ANY, data=mock.ANY
)
http.reset_mock()
http.is_mtls = False
mock_response.status_code = 200
mock_response.json.return_value = self._make_table_resource()
user_agent_override = client_info.ClientInfo(user_agent="my-application/1.2.3")
Expand Down Expand Up @@ -4425,7 +4426,7 @@ def _mock_transport(self, status_code, headers, content=b""):
fake_transport.request.return_value = fake_response
return fake_transport

def _initiate_resumable_upload_helper(self, num_retries=None):
def _initiate_resumable_upload_helper(self, num_retries=None, mtls=False):
from google.resumable_media.requests import ResumableUpload
from google.cloud.bigquery.client import _DEFAULT_CHUNKSIZE
from google.cloud.bigquery.client import _GENERIC_CONTENT_TYPE
Expand All @@ -4440,6 +4441,8 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
fake_transport = self._mock_transport(http.client.OK, response_headers)
client = self._make_one(project=self.PROJECT, _http=fake_transport)
conn = client._connection = make_connection()
if mtls:
conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls")

# Create some mock arguments and call the method under test.
data = b"goodbye gudbi gootbee"
Expand All @@ -4454,8 +4457,10 @@ def _initiate_resumable_upload_helper(self, num_retries=None):

# Check the returned values.
self.assertIsInstance(upload, ResumableUpload)

host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com"
upload_url = (
f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{self.PROJECT}"
f"{host_name}/upload/bigquery/v2/projects/{self.PROJECT}"
"/jobs?uploadType=resumable"
)
self.assertEqual(upload.upload_url, upload_url)
Expand Down Expand Up @@ -4494,11 +4499,14 @@ def _initiate_resumable_upload_helper(self, num_retries=None):
def test__initiate_resumable_upload(self):
self._initiate_resumable_upload_helper()

def test__initiate_resumable_upload_mtls(self):
self._initiate_resumable_upload_helper(mtls=True)

def test__initiate_resumable_upload_with_retry(self):
self._initiate_resumable_upload_helper(num_retries=11)

def _do_multipart_upload_success_helper(
self, get_boundary, num_retries=None, project=None
self, get_boundary, num_retries=None, project=None, mtls=False
):
from google.cloud.bigquery.client import _get_upload_headers
from google.cloud.bigquery.job import LoadJob
Expand All @@ -4508,6 +4516,8 @@ def _do_multipart_upload_success_helper(
fake_transport = self._mock_transport(http.client.OK, {})
client = self._make_one(project=self.PROJECT, _http=fake_transport)
conn = client._connection = make_connection()
if mtls:
conn.get_api_base_url_for_mtls = mock.Mock(return_value="https://foo.mtls")

if project is None:
project = self.PROJECT
Expand All @@ -4530,8 +4540,9 @@ def _do_multipart_upload_success_helper(
self.assertEqual(stream.tell(), size)
get_boundary.assert_called_once_with()

host_name = "https://foo.mtls" if mtls else "https://bigquery.googleapis.com"
upload_url = (
f"https://bigquery.googleapis.com/upload/bigquery/v2/projects/{project}"
f"{host_name}/upload/bigquery/v2/projects/{project}"
"/jobs?uploadType=multipart"
)
payload = (
Expand All @@ -4556,6 +4567,10 @@ def _do_multipart_upload_success_helper(
def test__do_multipart_upload(self, get_boundary):
self._do_multipart_upload_success_helper(get_boundary)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_mtls(self, get_boundary):
self._do_multipart_upload_success_helper(get_boundary, mtls=True)

@mock.patch("google.resumable_media._upload.get_boundary", return_value=b"==0==")
def test__do_multipart_upload_with_retry(self, get_boundary):
self._do_multipart_upload_success_helper(get_boundary, num_retries=8)
Expand Down