diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index 412f32754..7553726fa 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -53,6 +53,8 @@
 
 _PROGRESS_INTERVAL = 0.2  # Maximum time between download status checks, in seconds.
 
+_MAX_QUEUE_SIZE_DEFAULT = object()  # max queue size sentinel for BQ Storage downloads
+
 _PANDAS_DTYPE_TO_BQ = {
     "bool": "BOOLEAN",
     "datetime64[ns, UTC]": "TIMESTAMP",
@@ -616,6 +618,7 @@ def _download_table_bqstorage(
     preserve_order=False,
     selected_fields=None,
     page_to_item=None,
+    max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
 ):
     """Use (faster, but billable) BQ Storage API to construct DataFrame."""
 
@@ -667,7 +670,17 @@
     download_state = _DownloadState()
 
     # Create a queue to collect frames as they are created in each thread.
-    worker_queue = queue.Queue()
+    #
+    # The queue needs to be bounded by default, because if the user code processes the
+    # fetched result pages too slowly, while at the same time new pages are rapidly being
+    # fetched from the server, the queue can grow to the point where the process runs
+    # out of memory.
+    if max_queue_size is _MAX_QUEUE_SIZE_DEFAULT:
+        max_queue_size = total_streams
+    elif max_queue_size is None:
+        max_queue_size = 0  # unbounded
+
+    worker_queue = queue.Queue(maxsize=max_queue_size)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=total_streams) as pool:
         try:
@@ -708,15 +721,12 @@
                     continue
 
             # Return any remaining values after the workers finished.
-            while not worker_queue.empty():  # pragma: NO COVER
+            while True:  # pragma: NO COVER
                 try:
-                    # Include a timeout because even though the queue is
-                    # non-empty, it doesn't guarantee that a subsequent call to
-                    # get() will not block.
-                    frame = worker_queue.get(timeout=_PROGRESS_INTERVAL)
+                    frame = worker_queue.get_nowait()
                     yield frame
                 except queue.Empty:  # pragma: NO COVER
-                    continue
+                    break
         finally:
             # No need for a lock because reading/replacing a variable is
             # defined to be an atomic operation in the Python language
@@ -729,7 +739,7 @@
 
 
 def download_arrow_bqstorage(
-    project_id, table, bqstorage_client, preserve_order=False, selected_fields=None
+    project_id, table, bqstorage_client, preserve_order=False, selected_fields=None,
 ):
     return _download_table_bqstorage(
         project_id,
@@ -749,6 +759,7 @@ def download_dataframe_bqstorage(
     dtypes,
     preserve_order=False,
     selected_fields=None,
+    max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
 ):
     page_to_item = functools.partial(_bqstorage_page_to_dataframe, column_names, dtypes)
     return _download_table_bqstorage(
@@ -758,6 +769,7 @@
         preserve_order=preserve_order,
         selected_fields=selected_fields,
         page_to_item=page_to_item,
+        max_queue_size=max_queue_size,
     )
 
 
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index a2366b806..bd5bca30f 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1490,13 +1490,12 @@ def _to_page_iterable(
         if not self._validate_bqstorage(bqstorage_client, False):
             bqstorage_client = None
 
-        if bqstorage_client is not None:
-            for item in bqstorage_download():
-                yield item
-            return
-
-        for item in tabledata_list_download():
-            yield item
+        result_pages = (
+            bqstorage_download()
+            if bqstorage_client is not None
+            else tabledata_list_download()
+        )
+        yield from result_pages
 
     def _to_arrow_iterable(self, bqstorage_client=None):
         """Create an iterable of arrow RecordBatches, to process the table as a stream."""
@@ -1622,7 +1621,12 @@ def to_arrow(
         arrow_schema = _pandas_helpers.bq_to_arrow_schema(self._schema)
         return pyarrow.Table.from_batches(record_batches, schema=arrow_schema)
 
-    def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
+    def to_dataframe_iterable(
+        self,
+        bqstorage_client=None,
+        dtypes=None,
+        max_queue_size=_pandas_helpers._MAX_QUEUE_SIZE_DEFAULT,
+    ):
         """Create an iterable of pandas DataFrames, to process the table as a stream.
 
         Args:
@@ -1642,6 +1646,17 @@ def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
                 ``dtype`` is used when constructing the series for the column
                 specified. Otherwise, the default pandas behavior is used.
 
+            max_queue_size (Optional[int]):
+                The maximum number of result pages to hold in the internal queue when
+                streaming query results over the BigQuery Storage API. Ignored if the
+                BigQuery Storage API is not used.
+
+                By default, the max queue size is set to the number of BQ Storage streams
+                created by the server. If ``max_queue_size`` is :data:`None`, the queue
+                size is infinite.
+
+                .. versionadded:: 2.14.0
+
         Returns:
             pandas.DataFrame:
                 A generator of :class:`~pandas.DataFrame`.
@@ -1665,6 +1680,7 @@ def to_dataframe_iterable(self, bqstorage_client=None, dtypes=None):
             dtypes,
             preserve_order=self._preserve_order,
             selected_fields=self._selected_fields,
+            max_queue_size=max_queue_size,
         )
         tabledata_list_download = functools.partial(
             _pandas_helpers.download_dataframe_row_iterator,
diff --git a/tests/unit/test__pandas_helpers.py b/tests/unit/test__pandas_helpers.py
index abd725820..43692f4af 100644
--- a/tests/unit/test__pandas_helpers.py
+++ b/tests/unit/test__pandas_helpers.py
@@ -17,6 +17,7 @@
 import decimal
 import functools
 import operator
+import queue
 import warnings
 
 import mock
@@ -41,6 +42,11 @@
 from google.cloud.bigquery import schema
 from google.cloud.bigquery._pandas_helpers import _BIGNUMERIC_SUPPORT
 
+try:
+    from google.cloud import bigquery_storage
+except ImportError:  # pragma: NO COVER
+    bigquery_storage = None
+
 skip_if_no_bignumeric = pytest.mark.skipif(
     not _BIGNUMERIC_SUPPORT, reason="BIGNUMERIC support requires pyarrow>=3.0.0",
 )
@@ -1265,6 +1271,66 @@ def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
     assert schema_arg == expected_schema_arg
 
 
+@pytest.mark.parametrize(
+    "stream_count,maxsize_kwarg,expected_call_count,expected_maxsize",
+    [
+        (3, {"max_queue_size": 2}, 3, 2),  # custom queue size
+        (4, {}, 4, 4),  # default queue size
+        (7, {"max_queue_size": None}, 7, 0),  # infinite queue size
+    ],
+)
+@pytest.mark.skipif(
+    bigquery_storage is None, reason="Requires `google-cloud-bigquery-storage`"
+)
+def test__download_table_bqstorage(
+    module_under_test,
+    stream_count,
+    maxsize_kwarg,
+    expected_call_count,
+    expected_maxsize,
+):
+    from google.cloud.bigquery import dataset
+    from google.cloud.bigquery import table
+
+    queue_used = None  # A reference to the queue used by code under test.
+
+    bqstorage_client = mock.create_autospec(
+        bigquery_storage.BigQueryReadClient, instance=True
+    )
+    fake_session = mock.Mock(streams=[f"stream/s{i}" for i in range(stream_count)])
+    bqstorage_client.create_read_session.return_value = fake_session
+
+    table_ref = table.TableReference(
+        dataset.DatasetReference("project-x", "dataset-y"), "table-z",
+    )
+
+    def fake_download_stream(
+        download_state, bqstorage_client, session, stream, worker_queue, page_to_item
+    ):
+        nonlocal queue_used
+        queue_used = worker_queue
+        try:
+            worker_queue.put_nowait("result_page")
+        except queue.Full:  # pragma: NO COVER
+            pass
+
+    download_stream = mock.Mock(side_effect=fake_download_stream)
+
+    with mock.patch.object(
+        module_under_test, "_download_table_bqstorage_stream", new=download_stream
+    ):
+        result_gen = module_under_test._download_table_bqstorage(
+            "some-project", table_ref, bqstorage_client, **maxsize_kwarg
+        )
+        list(result_gen)
+
+    # Timing-safe, as the method under test should block until the pool shutdown is
+    # complete, at which point all download stream workers have already been submitted
+    # to the thread pool.
+    assert download_stream.call_count == stream_count  # once for each stream
+    assert queue_used.maxsize == expected_maxsize
+
+
 @pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
 def test_download_arrow_row_iterator_unknown_field_type(module_under_test):
     fake_page = api_core.page_iterator.Page(
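
Usage note (not part of the diff): a minimal sketch of how a caller might exercise the new max_queue_size option added above. The project/table ID and the per-frame processing are placeholders, and it assumes google-cloud-bigquery-storage is installed.

    # Hypothetical usage sketch for RowIterator.to_dataframe_iterable(max_queue_size=...).
    # The table ID and the per-frame work below are placeholders, not part of this change.
    from google.cloud import bigquery
    from google.cloud import bigquery_storage

    bq_client = bigquery.Client()
    read_client = bigquery_storage.BigQueryReadClient()

    rows = bq_client.list_rows("my-project.my_dataset.my_table")  # placeholder table ID

    # Buffer at most two fetched result pages while each DataFrame is being processed;
    # max_queue_size=None would make the internal queue unbounded instead.
    for frame in rows.to_dataframe_iterable(
        bqstorage_client=read_client, max_queue_size=2
    ):
        print(len(frame))  # stand-in for slower, user-defined processing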