From 91fc6d9870a36308b15a827ed6a691e5b4669b62 Mon Sep 17 00:00:00 2001 From: cojenco <59401799+cojenco@users.noreply.github.com> Date: Tue, 9 Mar 2021 14:39:24 -0800 Subject: [PATCH] fix: update batch connection to request api endpoint info from client (#392) * bug: let Batch Connection request api endpoint info from Client * add unit test coverage Co-authored-by: Cathy Ouyang --- google/cloud/storage/batch.py | 6 +++++- tests/unit/test_batch.py | 40 +++++++++++++++++++++++++++++------ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/google/cloud/storage/batch.py b/google/cloud/storage/batch.py index abfc88412..d40fdc6f5 100644 --- a/google/cloud/storage/batch.py +++ b/google/cloud/storage/batch.py @@ -147,7 +147,11 @@ class Batch(Connection): _MAX_BATCH_SIZE = 1000 def __init__(self, client): - super(Batch, self).__init__(client) + api_endpoint = client._connection.API_BASE_URL + client_info = client._connection._client_info + super(Batch, self).__init__( + client, client_info=client_info, api_endpoint=api_endpoint + ) self._requests = [] self._target_objects = [] diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index ec8fe75de..d43f27e8e 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -136,7 +136,8 @@ def test__make_request_GET_normal(self): url = "http://example.com/api" http = _make_requests_session([]) connection = _Connection(http=http) - batch = self._make_one(connection) + client = _Client(connection) + batch = self._make_one(client) target = _MockObject() response = batch._make_request("GET", url, target_object=target) @@ -164,7 +165,8 @@ def test__make_request_POST_normal(self): url = "http://example.com/api" http = _make_requests_session([]) connection = _Connection(http=http) - batch = self._make_one(connection) + client = _Client(connection) + batch = self._make_one(client) data = {"foo": 1} target = _MockObject() @@ -191,7 +193,8 @@ def test__make_request_PATCH_normal(self): url = "http://example.com/api" http = _make_requests_session([]) connection = _Connection(http=http) - batch = self._make_one(connection) + client = _Client(connection) + batch = self._make_one(client) data = {"foo": 1} target = _MockObject() @@ -218,7 +221,8 @@ def test__make_request_DELETE_normal(self): url = "http://example.com/api" http = _make_requests_session([]) connection = _Connection(http=http) - batch = self._make_one(connection) + client = _Client(connection) + batch = self._make_one(client) target = _MockObject() response = batch._make_request("DELETE", url, target_object=target) @@ -243,7 +247,8 @@ def test__make_request_POST_too_many_requests(self): url = "http://example.com/api" http = _make_requests_session([]) connection = _Connection(http=http) - batch = self._make_one(connection) + client = _Client(connection) + batch = self._make_one(client) batch._MAX_BATCH_SIZE = 1 batch._requests.append(("POST", url, {}, {"bar": 2})) @@ -254,7 +259,8 @@ def test__make_request_POST_too_many_requests(self): def test_finish_empty(self): http = _make_requests_session([]) connection = _Connection(http=http) - batch = self._make_one(connection) + client = _Client(connection) + batch = self._make_one(client) with self.assertRaises(ValueError): batch.finish() @@ -518,6 +524,25 @@ def test_as_context_mgr_w_error(self): self.assertIsInstance(target2._properties, _FutureDict) self.assertIsInstance(target3._properties, _FutureDict) + def test_respect_client_existing_connection(self): + client_endpoint = "http://localhost:9023" + http = _make_requests_session([]) + connection = _Connection(http=http, api_endpoint=client_endpoint) + client = _Client(connection) + batch = self._make_one(client) + self.assertEqual(batch.API_BASE_URL, client_endpoint) + self.assertEqual(batch._client._connection.API_BASE_URL, client_endpoint) + + def test_use_default_api_without_existing_connection(self): + default_api_endpoint = "https://storage.googleapis.com" + http = _make_requests_session([]) + connection = _Connection(http=http) + client = _Client(connection) + batch = self._make_one(client) + self.assertEqual(batch.API_BASE_URL, default_api_endpoint) + self.assertIsNone(batch._client._connection.API_BASE_URL) + self.assertIsNone(batch._client._connection._client_info) + class Test__unpack_batch_response(unittest.TestCase): def _call_fut(self, headers, content): @@ -633,6 +658,8 @@ class _Connection(object): def __init__(self, **kw): self.__dict__.update(kw) + self._client_info = kw.get("client_info", None) + self.API_BASE_URL = kw.get("api_endpoint", None) def _make_request(self, method, url, data=None, headers=None, timeout=None): return self.http.request( @@ -647,3 +674,4 @@ class _MockObject(object): class _Client(object): def __init__(self, connection): self._base_connection = connection + self._connection = connection