Skip to content

Commit

Permalink
fix: update batch connection to request api endpoint info from client (
Browse files Browse the repository at this point in the history
…#392)

* bug: let Batch Connection request api endpoint info from Client

* add unit test coverage

Co-authored-by: Cathy Ouyang <cathyo@google.com>
  • Loading branch information
cojenco and cojenco committed Mar 9, 2021
1 parent d346c94 commit 91fc6d9
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
6 changes: 5 additions & 1 deletion google/cloud/storage/batch.py
Expand Up @@ -147,7 +147,11 @@ class Batch(Connection):
_MAX_BATCH_SIZE = 1000

def __init__(self, client):
super(Batch, self).__init__(client)
api_endpoint = client._connection.API_BASE_URL
client_info = client._connection._client_info
super(Batch, self).__init__(
client, client_info=client_info, api_endpoint=api_endpoint
)
self._requests = []
self._target_objects = []

Expand Down
40 changes: 34 additions & 6 deletions tests/unit/test_batch.py
Expand Up @@ -136,7 +136,8 @@ def test__make_request_GET_normal(self):
url = "http://example.com/api"
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)
client = _Client(connection)
batch = self._make_one(client)
target = _MockObject()

response = batch._make_request("GET", url, target_object=target)
Expand Down Expand Up @@ -164,7 +165,8 @@ def test__make_request_POST_normal(self):
url = "http://example.com/api"
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)
client = _Client(connection)
batch = self._make_one(client)
data = {"foo": 1}
target = _MockObject()

Expand All @@ -191,7 +193,8 @@ def test__make_request_PATCH_normal(self):
url = "http://example.com/api"
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)
client = _Client(connection)
batch = self._make_one(client)
data = {"foo": 1}
target = _MockObject()

Expand All @@ -218,7 +221,8 @@ def test__make_request_DELETE_normal(self):
url = "http://example.com/api"
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)
client = _Client(connection)
batch = self._make_one(client)
target = _MockObject()

response = batch._make_request("DELETE", url, target_object=target)
Expand All @@ -243,7 +247,8 @@ def test__make_request_POST_too_many_requests(self):
url = "http://example.com/api"
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)
client = _Client(connection)
batch = self._make_one(client)

batch._MAX_BATCH_SIZE = 1
batch._requests.append(("POST", url, {}, {"bar": 2}))
Expand All @@ -254,7 +259,8 @@ def test__make_request_POST_too_many_requests(self):
def test_finish_empty(self):
http = _make_requests_session([])
connection = _Connection(http=http)
batch = self._make_one(connection)
client = _Client(connection)
batch = self._make_one(client)

with self.assertRaises(ValueError):
batch.finish()
Expand Down Expand Up @@ -518,6 +524,25 @@ def test_as_context_mgr_w_error(self):
self.assertIsInstance(target2._properties, _FutureDict)
self.assertIsInstance(target3._properties, _FutureDict)

def test_respect_client_existing_connection(self):
client_endpoint = "http://localhost:9023"
http = _make_requests_session([])
connection = _Connection(http=http, api_endpoint=client_endpoint)
client = _Client(connection)
batch = self._make_one(client)
self.assertEqual(batch.API_BASE_URL, client_endpoint)
self.assertEqual(batch._client._connection.API_BASE_URL, client_endpoint)

def test_use_default_api_without_existing_connection(self):
default_api_endpoint = "https://storage.googleapis.com"
http = _make_requests_session([])
connection = _Connection(http=http)
client = _Client(connection)
batch = self._make_one(client)
self.assertEqual(batch.API_BASE_URL, default_api_endpoint)
self.assertIsNone(batch._client._connection.API_BASE_URL)
self.assertIsNone(batch._client._connection._client_info)


class Test__unpack_batch_response(unittest.TestCase):
def _call_fut(self, headers, content):
Expand Down Expand Up @@ -633,6 +658,8 @@ class _Connection(object):

def __init__(self, **kw):
self.__dict__.update(kw)
self._client_info = kw.get("client_info", None)
self.API_BASE_URL = kw.get("api_endpoint", None)

def _make_request(self, method, url, data=None, headers=None, timeout=None):
return self.http.request(
Expand All @@ -647,3 +674,4 @@ class _MockObject(object):
class _Client(object):
def __init__(self, connection):
self._base_connection = connection
self._connection = connection

0 comments on commit 91fc6d9

Please sign in to comment.