feat: add max_queue_size argument to RowIterator.to_dataframe_iterable #575

Merged
merged 5 commits on Apr 14, 2021
28 changes: 20 additions & 8 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -45,6 +45,8 @@

_PROGRESS_INTERVAL = 0.2 # Maximum time between download status checks, in seconds.

_MAX_QUEUE_SIZE_DEFAULT = object() # max queue size sentinel for BQ Storage downloads

_PANDAS_DTYPE_TO_BQ = {
"bool": "BOOLEAN",
"datetime64[ns, UTC]": "TIMESTAMP",
@@ -608,6 +610,7 @@ def _download_table_bqstorage(
preserve_order=False,
selected_fields=None,
page_to_item=None,
max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
):
"""Use (faster, but billable) BQ Storage API to construct DataFrame."""

@@ -654,7 +657,17 @@ def _download_table_bqstorage(
download_state = _DownloadState()

# Create a queue to collect frames as they are created in each thread.
worker_queue = queue.Queue()
#
# The queue needs to be bounded by default, because if the user code processes the
# fetched result pages too slowly, while at the same time new pages are rapidly being
# fetched from the server, the queue can grow to the point where the process runs
# out of memory.
if max_queue_size is _MAX_QUEUE_SIZE_DEFAULT:
max_queue_size = total_streams
elif max_queue_size is None:
max_queue_size = 0 # unbounded

worker_queue = queue.Queue(maxsize=max_queue_size)

with concurrent.futures.ThreadPoolExecutor(max_workers=total_streams) as pool:
try:
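
The bounded-queue comment above is the heart of the change: when user code drains pages slowly, queue.Queue.put blocks the download workers instead of letting fetched pages pile up in memory. A small self-contained sketch of that backpressure behavior, with the sizes and delays chosen only for illustration:

import queue
import threading
import time

page_queue = queue.Queue(maxsize=2)  # bounded: at most 2 unconsumed pages in memory

def producer():
    for page in range(6):
        page_queue.put(page)  # blocks whenever 2 pages are already waiting
        print(f"produced page {page}")

threading.Thread(target=producer, daemon=True).start()

for _ in range(6):
    time.sleep(0.5)  # simulate user code that processes each page slowly
    print(f"consumed page {page_queue.get()}")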
@@ -695,15 +708,12 @@
continue

# Return any remaining values after the workers finished.
while not worker_queue.empty(): # pragma: NO COVER
while True: # pragma: NO COVER
try:
# Include a timeout because even though the queue is
# non-empty, it doesn't guarantee that a subsequent call to
# get() will not block.
frame = worker_queue.get(timeout=_PROGRESS_INTERVAL)
frame = worker_queue.get_nowait()
yield frame
except queue.Empty: # pragma: NO COVER
continue
break
finally:
# No need for a lock because reading/replacing a variable is
# defined to be an atomic operation in the Python language
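
The rewritten drain loop above relies on the fact that, once the workers have finished, nothing can add to the queue anymore, so a non-blocking get_nowait() until queue.Empty is sufficient. The same idiom in isolation, assuming q is a queue.Queue with no remaining producers:

import queue

q = queue.Queue()
for item in ("a", "b", "c"):
    q.put(item)

drained = []
while True:
    try:
        drained.append(q.get_nowait())  # never blocks; raises queue.Empty when drained
    except queue.Empty:
        break

assert drained == ["a", "b", "c"]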
@@ -716,7 +726,7 @@


def download_arrow_bqstorage(
project_id, table, bqstorage_client, preserve_order=False, selected_fields=None
project_id, table, bqstorage_client, preserve_order=False, selected_fields=None,
):
return _download_table_bqstorage(
project_id,
@@ -736,6 +746,7 @@ def download_dataframe_bqstorage(
dtypes,
preserve_order=False,
selected_fields=None,
max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
):
page_to_item = functools.partial(_bqstorage_page_to_dataframe, column_names, dtypes)
return _download_table_bqstorage(
@@ -745,6 +756,7 @@ def download_dataframe_bqstorage(
preserve_order=preserve_order,
selected_fields=selected_fields,
page_to_item=page_to_item,
max_queue_size=max_queue_size,
)


32 changes: 24 additions & 8 deletions google/cloud/bigquery/table.py
@@ -1490,13 +1490,12 @@ def _to_page_iterable(
if not self._validate_bqstorage(bqstorage_client, False):
bqstorage_client = None

if bqstorage_client is not None:
for item in bqstorage_download():
yield item
return

for item in tabledata_list_download():
yield item
result_pages = (
bqstorage_download()
if bqstorage_client is not None
else tabledata_list_download()
)
yield from result_pages

def _to_arrow_iterable(self, bqstorage_client=None):
"""Create an iterable of arrow RecordBatches, to process the table as a stream."""
@@ -1622,7 +1621,12 @@ def to_arrow(
arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)

def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
def to_dataframe_iterable(
self,
bqstorage_client=None,
dtypes=None,
max_queue_size=_pandas_helpers._MAX_QUEUE_SIZE_DEFAULT,
):
"""Create an iterable of pandas DataFrames, to process the table as a stream.

Args:
@@ -1642,6 +1646,17 @@ def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
``dtype`` is used when constructing the series for the column
specified. Otherwise, the default pandas behavior is used.

max_queue_size (Optional[int]):
The maximum number of result pages to hold in the internal queue when
streaming query results over the BigQuery Storage API. Ignored if
the BigQuery Storage API is not used.

By default, the max queue size is set to the number of BQ Storage streams
created by the server. If ``max_queue_size`` is :data:`None`, the queue
size is infinite.
Comment on lines +1654 to +1656
Contributor Author:
Just in case somebody really wants the old behavior, I added it as an option.
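
For illustration, a rough sketch of how a caller could use the new argument (the client setup, query, and per-chunk processing below are assumptions made for the example, not part of this PR):

from google.cloud import bigquery
from google.cloud import bigquery_storage

bq_client = bigquery.Client()  # project/credentials assumed to come from the environment
bqstorage_client = bigquery_storage.BigQueryReadClient()

rows = bq_client.query(
    "SELECT name, number FROM `bigquery-public-data.usa_names.usa_1910_2013`"
).result()

# Buffer at most 5 result pages; pass max_queue_size=None to restore the old unbounded queue.
for frame in rows.to_dataframe_iterable(
    bqstorage_client=bqstorage_client, max_queue_size=5
):
    print(len(frame))  # placeholder for real per-chunk processing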


.. versionadded:: 2.14.0

Returns:
pandas.DataFrame:
A generator of :class:`~pandas.DataFrame`.
@@ -1665,6 +1680,7 @@ def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
dtypes,
preserve_order=self._preserve_order,
selected_fields=self._selected_fields,
max_queue_size=max_queue_size,
)
tabledata_list_download = functools.partial(
_pandas_helpers.download_dataframe_row_iterator,
66 changes: 66 additions & 0 deletions tests/unit/test__pandas_helpers.py
@@ -17,6 +17,7 @@
import decimal
import functools
import operator
import queue
import warnings

import mock
@@ -41,6 +42,11 @@
from google.cloud.bigquery import schema
from google.cloud.bigquery._pandas_helpers import _BIGNUMERIC_SUPPORT

try:
from google.cloud import bigquery_storage
except ImportError: # pragma: NO COVER
bigquery_storage = None


skip_if_no_bignumeric = pytest.mark.skipif(
not _BIGNUMERIC_SUPPORT, reason="BIGNUMERIC support requires pyarrow>=3.0.0",
@@ -1265,6 +1271,66 @@ def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
assert schema_arg == expected_schema_arg


@pytest.mark.parametrize(
"stream_count,maxsize_kwarg,expected_call_count,expected_maxsize",
[
(3, {"max_queue_size": 2}, 3, 2), # custom queue size
(4, {}, 4, 4), # default queue size
(7, {"max_queue_size": None}, 7, 0), # infinite queue size
],
)
@pytest.mark.skipif(
bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`"
)
def test__download_table_bqstorage(
module_under_test,
stream_count,
maxsize_kwarg,
expected_call_count,
expected_maxsize,
):
from google.cloud.bigquery import dataset
from google.cloud.bigquery import table

queue_used = None # A reference to the queue used by code under test.

bqstorage_client = mock.create_autospec(
bigquery_storage.BigQueryReadClient, instance=True
)
fake_session = mock.Mock(streams=[f"stream/s{i}" for i in range(stream_count)])
bqstorage_client.create_read_session.return_value = fake_session

table_ref = table.TableReference(
dataset.DatasetReference("project-x", "dataset-y"), "table-z",
)

def fake_download_stream(
download_state, bqstorage_client, session, stream, worker_queue, page_to_item
):
nonlocal queue_used
queue_used = worker_queue
try:
worker_queue.put_nowait("result_page")
except queue.Full: # pragma: NO COVER
pass

download_stream = mock.Mock(side_effect=fake_download_stream)

with mock.patch.object(
module_under_test, "_download_table_bqstorage_stream", new=download_stream
):
result_gen = module_under_test._download_table_bqstorage(
"some-project", table_ref, bqstorage_client, **maxsize_kwarg
)
list(result_gen)

# Timing-safe, as the method under test should block until the pool shutdown is
# complete, at which point all download stream workers have already been submitted
# to the thread pool.
assert download_stream.call_count == stream_count # once for each stream
assert queue_used.maxsize == expected_maxsize


@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
def test_download_arrow_row_iterator_unknown_field_type(module_under_test):
fake_page = api_core.page_iterator.Page(