Skip to content

Commit

Permalink
feat: add retry/timeout to 'query.CollectionGroup.get_partitions'
Browse files Browse the repository at this point in the history
Toward #221
  • Loading branch information
tseaver committed Oct 13, 2020
1 parent e6ad4a1 commit 6e806b0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
17 changes: 16 additions & 1 deletion google/cloud/firestore_v1/query.py
Expand Up @@ -302,7 +302,9 @@ def __init__(
all_descendants=all_descendants,
)

def get_partitions(self, partition_count) -> Generator[QueryPartition, None, None]:
def get_partitions(
self, partition_count, retry: retries.Retry = None, timeout: float = None
) -> Generator[QueryPartition, None, None]:
"""Partition a query for parallelization.
Partitions a query by returning partition cursors that can be used to run the
Expand All @@ -313,6 +315,9 @@ def get_partitions(self, partition_count) -> Generator[QueryPartition, None, Non
partition_count (int): The desired maximum number of partition points. The
number must be strictly positive. The actual number of partitions
returned may be fewer.
retry (google.api_core.retry.Retry): Designation of what errors, if any,
should be retried.
timeout (float): The timeout for this request.
"""
self._validate_partition_query()
query = Query(
Expand All @@ -324,13 +329,23 @@ def get_partitions(self, partition_count) -> Generator[QueryPartition, None, Non
)

parent_path, expected_prefix = self._parent._parent_info()

kwargs = {}

if retry is not None:
kwargs["retry"] = retry

if timeout is not None:
kwargs["timeout"] = timeout

pager = self._client._firestore_api.partition_query(
request={
"parent": parent_path,
"structured_query": query._to_protobuf(),
"partition_count": partition_count,
},
metadata=self._client._rpc_metadata,
**kwargs,
)

start_at = None
Expand Down
24 changes: 22 additions & 2 deletions tests/unit/v1/test_query.py
Expand Up @@ -500,7 +500,7 @@ def test_constructor_all_descendents_is_false(self):
with pytest.raises(ValueError):
self._make_one(mock.sentinel.parent, all_descendants=False)

def test_get_partitions(self):
def _get_partitions_helper(self, retry=None, timeout=None):
# Create a minimal fake GAPIC.
firestore_api = mock.Mock(spec=["partition_query"])

Expand All @@ -522,7 +522,16 @@ def test_get_partitions(self):

# Execute the query and check the response.
query = self._make_one(parent)
get_response = query.get_partitions(2)

kwargs = {}

if retry is not None:
kwargs["retry"] = retry

if timeout is not None:
kwargs["timeout"] = timeout

get_response = query.get_partitions(2, **kwargs)
self.assertIsInstance(get_response, types.GeneratorType)
returned = list(get_response)
self.assertEqual(len(returned), 3)
Expand All @@ -539,8 +548,19 @@ def test_get_partitions(self):
"partition_count": 2,
},
metadata=client._rpc_metadata,
**kwargs,
)

def test_get_partitions(self):
self._get_partitions_helper()

def test_get_partitions_w_retry_timeout(self):
from google.api_core.retry import Retry

retry = Retry(predicate=object())
timeout = 123.0
self._get_partitions_helper(retry=retry, timeout=timeout)

def test_get_partitions_w_filter(self):
# Make a **real** collection reference as parent.
client = _make_client()
Expand Down

0 comments on commit 6e806b0

Please sign in to comment.