Skip to content

Commit

Permalink
feat: add max_results parameter to some of the QueryJob methods (#698)
Browse files Browse the repository at this point in the history
* feat: add max_results to a few QueryJob methods

It is now possible to cap the number of result rows returned when
invoking `to_dataframe()` or `to_arrow()` method on a `QueryJob`
instance.

* Work around a pytype complaint

* Make _EmptyRowIterator a subclass of RowIterator

Co-authored-by: Dan Lee <71398022+dandhlee@users.noreply.github.com>
  • Loading branch information
plamut and dandhlee committed Jun 25, 2021
1 parent b35e1ad commit 2a9618f
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 23 deletions.
37 changes: 32 additions & 5 deletions google/cloud/bigquery/_tqdm_helpers.py
Expand Up @@ -16,13 +16,19 @@

import concurrent.futures
import time
import typing
from typing import Optional
import warnings

try:
import tqdm
except ImportError: # pragma: NO COVER
tqdm = None

if typing.TYPE_CHECKING: # pragma: NO COVER
from google.cloud.bigquery import QueryJob
from google.cloud.bigquery.table import RowIterator

_NO_TQDM_ERROR = (
"A progress bar was requested, but there was an error loading the tqdm "
"library. Please install tqdm to use the progress bar functionality."
Expand All @@ -32,7 +38,7 @@


def get_progress_bar(progress_bar_type, description, total, unit):
"""Construct a tqdm progress bar object, if tqdm is ."""
"""Construct a tqdm progress bar object, if tqdm is installed."""
if tqdm is None:
if progress_bar_type is not None:
warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
Expand All @@ -53,16 +59,34 @@ def get_progress_bar(progress_bar_type, description, total, unit):
return None


def wait_for_query(query_job, progress_bar_type=None):
"""Return query result and display a progress bar while the query running, if tqdm is installed."""
def wait_for_query(
query_job: "QueryJob",
progress_bar_type: Optional[str] = None,
max_results: Optional[int] = None,
) -> "RowIterator":
"""Return query result and display a progress bar while the query running, if tqdm is installed.
Args:
query_job:
The job representing the execution of the query on the server.
progress_bar_type:
The type of progress bar to use to show query progress.
max_results:
The maximum number of rows the row iterator should return.
Returns:
A row iterator over the query results.
"""
default_total = 1
current_stage = None
start_time = time.time()

progress_bar = get_progress_bar(
progress_bar_type, "Query is running", default_total, "query"
)
if progress_bar is None:
return query_job.result()
return query_job.result(max_results=max_results)

i = 0
while True:
if query_job.query_plan:
Expand All @@ -75,7 +99,9 @@ def wait_for_query(query_job, progress_bar_type=None):
),
)
try:
query_result = query_job.result(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
query_result = query_job.result(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=max_results
)
progress_bar.update(default_total)
progress_bar.set_description(
"Query complete after {:0.2f}s".format(time.time() - start_time),
Expand All @@ -89,5 +115,6 @@ def wait_for_query(query_job, progress_bar_type=None):
progress_bar.update(i + 1)
i += 1
continue

progress_bar.close()
return query_result
22 changes: 18 additions & 4 deletions google/cloud/bigquery/job/query.py
Expand Up @@ -1300,12 +1300,14 @@ def result(
return rows

# If changing the signature of this method, make sure to apply the same
# changes to table.RowIterator.to_arrow()
# changes to table.RowIterator.to_arrow(), except for the max_results parameter
# that should only exist here in the QueryJob method.
def to_arrow(
self,
progress_bar_type: str = None,
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
create_bqstorage_client: bool = True,
max_results: Optional[int] = None,
) -> "pyarrow.Table":
"""[Beta] Create a class:`pyarrow.Table` by loading all pages of a
table or query.
Expand Down Expand Up @@ -1349,6 +1351,11 @@ def to_arrow(
..versionadded:: 1.24.0
max_results (Optional[int]):
Maximum number of rows to include in the result. No limit by default.
..versionadded:: 2.21.0
Returns:
pyarrow.Table
A :class:`pyarrow.Table` populated with row data and column
Expand All @@ -1361,22 +1368,24 @@ def to_arrow(
..versionadded:: 1.17.0
"""
query_result = wait_for_query(self, progress_bar_type)
query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
return query_result.to_arrow(
progress_bar_type=progress_bar_type,
bqstorage_client=bqstorage_client,
create_bqstorage_client=create_bqstorage_client,
)

# If changing the signature of this method, make sure to apply the same
# changes to table.RowIterator.to_dataframe()
# changes to table.RowIterator.to_dataframe(), except for the max_results parameter
# that should only exist here in the QueryJob method.
def to_dataframe(
self,
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
dtypes: Dict[str, Any] = None,
progress_bar_type: str = None,
create_bqstorage_client: bool = True,
date_as_object: bool = True,
max_results: Optional[int] = None,
) -> "pandas.DataFrame":
"""Return a pandas DataFrame from a QueryJob
Expand Down Expand Up @@ -1423,6 +1432,11 @@ def to_dataframe(
..versionadded:: 1.26.0
max_results (Optional[int]):
Maximum number of rows to include in the result. No limit by default.
..versionadded:: 2.21.0
Returns:
A :class:`~pandas.DataFrame` populated with row data and column
headers from the query results. The column headers are derived
Expand All @@ -1431,7 +1445,7 @@ def to_dataframe(
Raises:
ValueError: If the `pandas` library cannot be imported.
"""
query_result = wait_for_query(self, progress_bar_type)
query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
return query_result.to_dataframe(
bqstorage_client=bqstorage_client,
dtypes=dtypes,
Expand Down
53 changes: 49 additions & 4 deletions google/cloud/bigquery/table.py
Expand Up @@ -22,7 +22,7 @@
import operator
import pytz
import typing
from typing import Any, Dict, Iterable, Tuple
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple
import warnings

try:
Expand Down Expand Up @@ -1415,7 +1415,9 @@ class RowIterator(HTTPIterator):
"""A class for iterating through HTTP/JSON API row list responses.
Args:
client (google.cloud.bigquery.Client): The API client.
client (Optional[google.cloud.bigquery.Client]):
The API client instance. This should always be non-`None`, except for
subclasses that do not use it, namely the ``_EmptyRowIterator``.
api_request (Callable[google.cloud._http.JSONConnection.api_request]):
The function to use to make API requests.
path (str): The method path to query for the list of items.
Expand Down Expand Up @@ -1480,7 +1482,7 @@ def __init__(
self._field_to_index = _helpers._field_to_index_mapping(schema)
self._page_size = page_size
self._preserve_order = False
self._project = client.project
self._project = client.project if client is not None else None
self._schema = schema
self._selected_fields = selected_fields
self._table = table
Expand Down Expand Up @@ -1895,7 +1897,7 @@ def to_dataframe(
return df


class _EmptyRowIterator(object):
class _EmptyRowIterator(RowIterator):
"""An empty row iterator.
This class prevents API requests when there are no rows to fetch or rows
Expand All @@ -1907,6 +1909,18 @@ class _EmptyRowIterator(object):
pages = ()
total_rows = 0

def __init__(
self, client=None, api_request=None, path=None, schema=(), *args, **kwargs
):
super().__init__(
client=client,
api_request=api_request,
path=path,
schema=schema,
*args,
**kwargs,
)

def to_arrow(
self,
progress_bar_type=None,
Expand Down Expand Up @@ -1951,6 +1965,37 @@ def to_dataframe(
raise ValueError(_NO_PANDAS_ERROR)
return pandas.DataFrame()

def to_dataframe_iterable(
self,
bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
dtypes: Optional[Dict[str, Any]] = None,
max_queue_size: Optional[int] = None,
) -> Iterator["pandas.DataFrame"]:
"""Create an iterable of pandas DataFrames, to process the table as a stream.
..versionadded:: 2.21.0
Args:
bqstorage_client:
Ignored. Added for compatibility with RowIterator.
dtypes (Optional[Map[str, Union[str, pandas.Series.dtype]]]):
Ignored. Added for compatibility with RowIterator.
max_queue_size:
Ignored. Added for compatibility with RowIterator.
Returns:
An iterator yielding a single empty :class:`~pandas.DataFrame`.
Raises:
ValueError:
If the :mod:`pandas` library cannot be imported.
"""
if pandas is None:
raise ValueError(_NO_PANDAS_ERROR)
return iter((pandas.DataFrame(),))

def __iter__(self):
return iter(())

Expand Down
101 changes: 97 additions & 4 deletions tests/unit/job/test_query_pandas.py
Expand Up @@ -238,6 +238,41 @@ def test_to_arrow():
]


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_to_arrow_max_results_no_progress_bar():
from google.cloud.bigquery import table
from google.cloud.bigquery.job import QueryJob as target_class
from google.cloud.bigquery.schema import SchemaField

connection = _make_connection({})
client = _make_client(connection=connection)
begun_resource = _make_job_resource(job_type="query")
job = target_class.from_api_repr(begun_resource, client)

schema = [
SchemaField("name", "STRING", mode="REQUIRED"),
SchemaField("age", "INTEGER", mode="REQUIRED"),
]
rows = [
{"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
{"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]},
]
path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
row_iterator = table.RowIterator(client, api_request, path, schema)

result_patch = mock.patch(
"google.cloud.bigquery.job.QueryJob.result", return_value=row_iterator,
)
with result_patch as result_patch_tqdm:
tbl = job.to_arrow(create_bqstorage_client=False, max_results=123)

result_patch_tqdm.assert_called_once_with(max_results=123)

assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`")
def test_to_arrow_w_tqdm_w_query_plan():
Expand Down Expand Up @@ -290,7 +325,9 @@ def test_to_arrow_w_tqdm_w_query_plan():
assert result_patch_tqdm.call_count == 3
assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
Expand Down Expand Up @@ -341,7 +378,9 @@ def test_to_arrow_w_tqdm_w_pending_status():
assert result_patch_tqdm.call_count == 2
assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
Expand Down Expand Up @@ -716,7 +755,9 @@ def test_to_dataframe_w_tqdm_pending():
assert isinstance(df, pandas.DataFrame)
assert len(df) == 4 # verify the number of rows
assert list(df) == ["name", "age"] # verify the column names
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
Expand Down Expand Up @@ -774,4 +815,56 @@ def test_to_dataframe_w_tqdm():
assert isinstance(df, pandas.DataFrame)
assert len(df) == 4 # verify the number of rows
assert list(df), ["name", "age"] # verify the column names
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`")
def test_to_dataframe_w_tqdm_max_results():
from google.cloud.bigquery import table
from google.cloud.bigquery.job import QueryJob as target_class
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL

begun_resource = _make_job_resource(job_type="query")
schema = [
SchemaField("name", "STRING", mode="NULLABLE"),
SchemaField("age", "INTEGER", mode="NULLABLE"),
]
rows = [{"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}]

connection = _make_connection({})
client = _make_client(connection=connection)
job = target_class.from_api_repr(begun_resource, client)

path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
row_iterator = table.RowIterator(client, api_request, path, schema)

job._properties["statistics"] = {
"query": {
"queryPlan": [
{"name": "S00: Input", "id": "0", "status": "COMPLETE"},
{"name": "S01: Output", "id": "1", "status": "COMPLETE"},
]
},
}
reload_patch = mock.patch(
"google.cloud.bigquery.job._AsyncJob.reload", autospec=True
)
result_patch = mock.patch(
"google.cloud.bigquery.job.QueryJob.result",
side_effect=[concurrent.futures.TimeoutError, row_iterator],
)

with result_patch as result_patch_tqdm, reload_patch:
job.to_dataframe(
progress_bar_type="tqdm", create_bqstorage_client=False, max_results=3
)

assert result_patch_tqdm.call_count == 2
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=3
)

0 comments on commit 2a9618f

Please sign in to comment.