From 2a9618f4daaa4a014161e1a2f7376844eec9e8da Mon Sep 17 00:00:00 2001
From: Peter Lamut
Date: Fri, 25 Jun 2021 08:56:40 +0200
Subject: [PATCH] feat: add max_results parameter to some of the QueryJob methods (#698)

* feat: add max_results to a few QueryJob methods

It is now possible to cap the number of result rows returned when invoking
the `to_dataframe()` or `to_arrow()` method on a `QueryJob` instance.

* Work around a pytype complaint

* Make _EmptyRowIterator a subclass of RowIterator

Co-authored-by: Dan Lee <71398022+dandhlee@users.noreply.github.com>
---
 google/cloud/bigquery/_tqdm_helpers.py     |  37 +++++++-
 google/cloud/bigquery/job/query.py         |  22 ++++-
 google/cloud/bigquery/table.py             |  53 ++++++++++-
 tests/unit/job/test_query_pandas.py        | 101 ++++++++++++++++++++-
 tests/unit/test_signature_compatibility.py |  31 +++++--
 tests/unit/test_table.py                   |  19 ++++
 6 files changed, 240 insertions(+), 23 deletions(-)

diff --git a/google/cloud/bigquery/_tqdm_helpers.py b/google/cloud/bigquery/_tqdm_helpers.py
index 2fcf2a981..99e720e2b 100644
--- a/google/cloud/bigquery/_tqdm_helpers.py
+++ b/google/cloud/bigquery/_tqdm_helpers.py
@@ -16,6 +16,8 @@

 import concurrent.futures
 import time
+import typing
+from typing import Optional
 import warnings

 try:
@@ -23,6 +25,10 @@
 except ImportError:  # pragma: NO COVER
     tqdm = None

+if typing.TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.bigquery import QueryJob
+    from google.cloud.bigquery.table import RowIterator
+
 _NO_TQDM_ERROR = (
     "A progress bar was requested, but there was an error loading the tqdm "
     "library. Please install tqdm to use the progress bar functionality."
@@ -32,7 +38,7 @@


 def get_progress_bar(progress_bar_type, description, total, unit):
-    """Construct a tqdm progress bar object, if tqdm is ."""
+    """Construct a tqdm progress bar object, if tqdm is installed."""
     if tqdm is None:
         if progress_bar_type is not None:
             warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
@@ -53,16 +59,34 @@ def get_progress_bar(progress_bar_type, description, total, unit):
     return None


-def wait_for_query(query_job, progress_bar_type=None):
-    """Return query result and display a progress bar while the query running, if tqdm is installed."""
+def wait_for_query(
+    query_job: "QueryJob",
+    progress_bar_type: Optional[str] = None,
+    max_results: Optional[int] = None,
+) -> "RowIterator":
+    """Return the query result and display a progress bar while the query is running, if tqdm is installed.
+
+    Args:
+        query_job:
+            The job representing the execution of the query on the server.
+        progress_bar_type:
+            The type of progress bar to use to show query progress.
+        max_results:
+            The maximum number of rows the row iterator should return.
+
+    Returns:
+        A row iterator over the query results.
+    """
     default_total = 1
     current_stage = None
     start_time = time.time()
+
     progress_bar = get_progress_bar(
         progress_bar_type, "Query is running", default_total, "query"
     )
     if progress_bar is None:
-        return query_job.result()
+        return query_job.result(max_results=max_results)
+
     i = 0
     while True:
         if query_job.query_plan:
@@ -75,7 +99,9 @@ def wait_for_query(query_job, progress_bar_type=None):
             ),
         )
         try:
-            query_result = query_job.result(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
+            query_result = query_job.result(
+                timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=max_results
+            )
             progress_bar.update(default_total)
             progress_bar.set_description(
                 "Query complete after {:0.2f}s".format(time.time() - start_time),
@@ -89,5 +115,6 @@ def wait_for_query(query_job, progress_bar_type=None):
             progress_bar.update(i + 1)
             i += 1
             continue
+
     progress_bar.close()
     return query_result
diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py
index 455ef4632..6ff9f2647 100644
--- a/google/cloud/bigquery/job/query.py
+++ b/google/cloud/bigquery/job/query.py
@@ -1300,12 +1300,14 @@ def result(
         return rows

     # If changing the signature of this method, make sure to apply the same
-    # changes to table.RowIterator.to_arrow()
+    # changes to table.RowIterator.to_arrow(), except for the max_results parameter
+    # that should only exist here in the QueryJob method.
     def to_arrow(
         self,
         progress_bar_type: str = None,
         bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
         create_bqstorage_client: bool = True,
+        max_results: Optional[int] = None,
     ) -> "pyarrow.Table":
         """[Beta] Create a class:`pyarrow.Table` by loading all pages of a
         table or query.
@@ -1349,6 +1351,11 @@ def to_arrow(

                 ..versionadded:: 1.24.0

+            max_results (Optional[int]):
+                Maximum number of rows to include in the result. No limit by default.
+
+                ..versionadded:: 2.21.0
+
         Returns:
             pyarrow.Table
                 A :class:`pyarrow.Table` populated with row data and column
@@ -1361,7 +1368,7 @@ def to_arrow(

             ..versionadded:: 1.17.0
         """
-        query_result = wait_for_query(self, progress_bar_type)
+        query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
         return query_result.to_arrow(
             progress_bar_type=progress_bar_type,
             bqstorage_client=bqstorage_client,
@@ -1369,7 +1376,8 @@ def to_arrow(
         )

     # If changing the signature of this method, make sure to apply the same
-    # changes to table.RowIterator.to_dataframe()
+    # changes to table.RowIterator.to_dataframe(), except for the max_results parameter
+    # that should only exist here in the QueryJob method.
     def to_dataframe(
         self,
         bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
         dtypes: Dict[str, Any] = None,
         progress_bar_type: str = None,
         create_bqstorage_client: bool = True,
         date_as_object: bool = True,
+        max_results: Optional[int] = None,
     ) -> "pandas.DataFrame":
         """Return a pandas DataFrame from a QueryJob
@@ -1423,6 +1432,11 @@ def to_dataframe(

                 ..versionadded:: 1.26.0

+            max_results (Optional[int]):
+                Maximum number of rows to include in the result. No limit by default.
+
+                ..versionadded:: 2.21.0
+
         Returns:
             A :class:`~pandas.DataFrame` populated with row data and column
             headers from the query results. The column headers are derived
@@ -1431,7 +1445,7 @@ def to_dataframe(

         Raises:
             ValueError: If the `pandas` library cannot be imported.
""" - query_result = wait_for_query(self, progress_bar_type) + query_result = wait_for_query(self, progress_bar_type, max_results=max_results) return query_result.to_dataframe( bqstorage_client=bqstorage_client, dtypes=dtypes, diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index b12209252..a1c13c85d 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -22,7 +22,7 @@ import operator import pytz import typing -from typing import Any, Dict, Iterable, Tuple +from typing import Any, Dict, Iterable, Iterator, Optional, Tuple import warnings try: @@ -1415,7 +1415,9 @@ class RowIterator(HTTPIterator): """A class for iterating through HTTP/JSON API row list responses. Args: - client (google.cloud.bigquery.Client): The API client. + client (Optional[google.cloud.bigquery.Client]): + The API client instance. This should always be non-`None`, except for + subclasses that do not use it, namely the ``_EmptyRowIterator``. api_request (Callable[google.cloud._http.JSONConnection.api_request]): The function to use to make API requests. path (str): The method path to query for the list of items. @@ -1480,7 +1482,7 @@ def __init__( self._field_to_index = _helpers._field_to_index_mapping(schema) self._page_size = page_size self._preserve_order = False - self._project = client.project + self._project = client.project if client is not None else None self._schema = schema self._selected_fields = selected_fields self._table = table @@ -1895,7 +1897,7 @@ def to_dataframe( return df -class _EmptyRowIterator(object): +class _EmptyRowIterator(RowIterator): """An empty row iterator. This class prevents API requests when there are no rows to fetch or rows @@ -1907,6 +1909,18 @@ class _EmptyRowIterator(object): pages = () total_rows = 0 + def __init__( + self, client=None, api_request=None, path=None, schema=(), *args, **kwargs + ): + super().__init__( + client=client, + api_request=api_request, + path=path, + schema=schema, + *args, + **kwargs, + ) + def to_arrow( self, progress_bar_type=None, @@ -1951,6 +1965,37 @@ def to_dataframe( raise ValueError(_NO_PANDAS_ERROR) return pandas.DataFrame() + def to_dataframe_iterable( + self, + bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None, + dtypes: Optional[Dict[str, Any]] = None, + max_queue_size: Optional[int] = None, + ) -> Iterator["pandas.DataFrame"]: + """Create an iterable of pandas DataFrames, to process the table as a stream. + + ..versionadded:: 2.21.0 + + Args: + bqstorage_client: + Ignored. Added for compatibility with RowIterator. + + dtypes (Optional[Map[str, Union[str, pandas.Series.dtype]]]): + Ignored. Added for compatibility with RowIterator. + + max_queue_size: + Ignored. Added for compatibility with RowIterator. + + Returns: + An iterator yielding a single empty :class:`~pandas.DataFrame`. + + Raises: + ValueError: + If the :mod:`pandas` library cannot be imported. 
+ """ + if pandas is None: + raise ValueError(_NO_PANDAS_ERROR) + return iter((pandas.DataFrame(),)) + def __iter__(self): return iter(()) diff --git a/tests/unit/job/test_query_pandas.py b/tests/unit/job/test_query_pandas.py index 0f9623203..c537802f4 100644 --- a/tests/unit/job/test_query_pandas.py +++ b/tests/unit/job/test_query_pandas.py @@ -238,6 +238,41 @@ def test_to_arrow(): ] +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +def test_to_arrow_max_results_no_progress_bar(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from google.cloud.bigquery.schema import SchemaField + + connection = _make_connection({}) + client = _make_client(connection=connection) + begun_resource = _make_job_resource(job_type="query") + job = target_class.from_api_repr(begun_resource, client) + + schema = [ + SchemaField("name", "STRING", mode="REQUIRED"), + SchemaField("age", "INTEGER", mode="REQUIRED"), + ] + rows = [ + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, + ] + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", return_value=row_iterator, + ) + with result_patch as result_patch_tqdm: + tbl = job.to_arrow(create_bqstorage_client=False, max_results=123) + + result_patch_tqdm.assert_called_once_with(max_results=123) + + assert isinstance(tbl, pyarrow.Table) + assert tbl.num_rows == 2 + + @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") @pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") def test_to_arrow_w_tqdm_w_query_plan(): @@ -290,7 +325,9 @@ def test_to_arrow_w_tqdm_w_query_plan(): assert result_patch_tqdm.call_count == 3 assert isinstance(tbl, pyarrow.Table) assert tbl.num_rows == 2 - result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + result_patch_tqdm.assert_called_with( + timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None + ) @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") @@ -341,7 +378,9 @@ def test_to_arrow_w_tqdm_w_pending_status(): assert result_patch_tqdm.call_count == 2 assert isinstance(tbl, pyarrow.Table) assert tbl.num_rows == 2 - result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + result_patch_tqdm.assert_called_with( + timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None + ) @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") @@ -716,7 +755,9 @@ def test_to_dataframe_w_tqdm_pending(): assert isinstance(df, pandas.DataFrame) assert len(df) == 4 # verify the number of rows assert list(df) == ["name", "age"] # verify the column names - result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + result_patch_tqdm.assert_called_with( + timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None + ) @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @@ -774,4 +815,56 @@ def test_to_dataframe_w_tqdm(): assert isinstance(df, pandas.DataFrame) assert len(df) == 4 # verify the number of rows assert list(df), ["name", "age"] # verify the column names - result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + result_patch_tqdm.assert_called_with( + timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None + ) + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(tqdm is None, reason="Requires 
`tqdm`") +def test_to_dataframe_w_tqdm_max_results(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from google.cloud.bigquery.schema import SchemaField + from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL + + begun_resource = _make_job_resource(job_type="query") + schema = [ + SchemaField("name", "STRING", mode="NULLABLE"), + SchemaField("age", "INTEGER", mode="NULLABLE"), + ] + rows = [{"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}] + + connection = _make_connection({}) + client = _make_client(connection=connection) + job = target_class.from_api_repr(begun_resource, client) + + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + job._properties["statistics"] = { + "query": { + "queryPlan": [ + {"name": "S00: Input", "id": "0", "status": "COMPLETE"}, + {"name": "S01: Output", "id": "1", "status": "COMPLETE"}, + ] + }, + } + reload_patch = mock.patch( + "google.cloud.bigquery.job._AsyncJob.reload", autospec=True + ) + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", + side_effect=[concurrent.futures.TimeoutError, row_iterator], + ) + + with result_patch as result_patch_tqdm, reload_patch: + job.to_dataframe( + progress_bar_type="tqdm", create_bqstorage_client=False, max_results=3 + ) + + assert result_patch_tqdm.call_count == 2 + result_patch_tqdm.assert_called_with( + timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=3 + ) diff --git a/tests/unit/test_signature_compatibility.py b/tests/unit/test_signature_compatibility.py index e5016b0e5..07b823e2c 100644 --- a/tests/unit/test_signature_compatibility.py +++ b/tests/unit/test_signature_compatibility.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict import inspect import pytest @@ -32,12 +33,30 @@ def row_iterator_class(): def test_to_arrow_method_signatures_match(query_job_class, row_iterator_class): - sig = inspect.signature(query_job_class.to_arrow) - sig2 = inspect.signature(row_iterator_class.to_arrow) - assert sig == sig2 + query_job_sig = inspect.signature(query_job_class.to_arrow) + iterator_sig = inspect.signature(row_iterator_class.to_arrow) + + assert "max_results" in query_job_sig.parameters + + # Compare the signatures while ignoring the max_results parameter, which is + # specific to the method on QueryJob. + params = OrderedDict(query_job_sig.parameters) + del params["max_results"] + query_job_sig = query_job_sig.replace(parameters=params.values()) + + assert query_job_sig == iterator_sig def test_to_dataframe_method_signatures_match(query_job_class, row_iterator_class): - sig = inspect.signature(query_job_class.to_dataframe) - sig2 = inspect.signature(row_iterator_class.to_dataframe) - assert sig == sig2 + query_job_sig = inspect.signature(query_job_class.to_dataframe) + iterator_sig = inspect.signature(row_iterator_class.to_dataframe) + + assert "max_results" in query_job_sig.parameters + + # Compare the signatures while ignoring the max_results parameter, which is + # specific to the method on QueryJob. 
+    params = OrderedDict(query_job_sig.parameters)
+    del params["max_results"]
+    query_job_sig = query_job_sig.replace(parameters=params.values())
+
+    assert query_job_sig == iterator_sig
diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py
index 0f2ab00c1..f4038835c 100644
--- a/tests/unit/test_table.py
+++ b/tests/unit/test_table.py
@@ -1571,6 +1571,25 @@ def test_to_dataframe(self):
         self.assertIsInstance(df, pandas.DataFrame)
         self.assertEqual(len(df), 0)  # verify the number of rows

+    @mock.patch("google.cloud.bigquery.table.pandas", new=None)
+    def test_to_dataframe_iterable_error_if_pandas_is_none(self):
+        row_iterator = self._make_one()
+        with self.assertRaises(ValueError):
+            row_iterator.to_dataframe_iterable()
+
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_to_dataframe_iterable(self):
+        row_iterator = self._make_one()
+        df_iter = row_iterator.to_dataframe_iterable()
+
+        result = list(df_iter)
+
+        self.assertEqual(len(result), 1)
+        df = result[0]
+        self.assertIsInstance(df, pandas.DataFrame)
+        self.assertEqual(len(df), 0)  # Verify the number of rows.
+        self.assertEqual(len(df.columns), 0)
+

 class TestRowIterator(unittest.TestCase):
     def _class_under_test(self):
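
Usage note: a minimal sketch of the new parameter, not part of the patch itself. It caps how many rows the returned row iterator materializes locally; the query still runs to completion on the server. The client setup, project, dataset, and query string below are illustrative assumptions.

    from google.cloud import bigquery

    client = bigquery.Client()  # assumes application default credentials
    job = client.query("SELECT name, age FROM `my-project.my_dataset.people`")

    # Fetch at most 1000 rows into the DataFrame / Arrow table.
    df = job.to_dataframe(max_results=1000)
    tbl = job.to_arrow(max_results=1000)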
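The `_EmptyRowIterator.to_dataframe_iterable()` method added above yields exactly one empty DataFrame, so code that streams results does not need a special case for empty tables. A small sketch of that contract, mirroring the new unit test (`_EmptyRowIterator` is private and is constructed directly here only for illustration; pandas must be installed):

    from google.cloud.bigquery.table import _EmptyRowIterator

    frames = list(_EmptyRowIterator().to_dataframe_iterable())
    assert len(frames) == 1 and frames[0].empty  # a single empty DataFrame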
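The signature-compatibility tests compare the QueryJob methods with their RowIterator counterparts after dropping the QueryJob-only max_results parameter. A self-contained sketch of the same inspect technique, using hypothetical stand-in functions rather than the real methods:

    from collections import OrderedDict
    import inspect

    def with_extra(a, b, max_results=None):  # stands in for QueryJob.to_arrow
        pass

    def without_extra(a, b):  # stands in for RowIterator.to_arrow
        pass

    sig = inspect.signature(with_extra)
    params = OrderedDict(sig.parameters)
    del params["max_results"]  # ignore the QueryJob-only parameter
    assert sig.replace(parameters=params.values()) == inspect.signature(without_extra)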