feat: add max_results parameter to some of the QueryJob methods #698

Merged
merged 6 commits into from Jun 25, 2021
Changes from all commits
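
In short, this pull request threads a new `max_results` argument from `QueryJob.to_arrow()` and `QueryJob.to_dataframe()` through the tqdm helper down to `QueryJob.result()`. A minimal usage sketch, assuming google-cloud-bigquery >= 2.21.0 with default credentials; the project, dataset, and query text are placeholders, not part of the change:

    from google.cloud import bigquery

    client = bigquery.Client()  # assumes application-default credentials
    job = client.query("SELECT name, age FROM `my-project.my_dataset.people`")

    # Cap the underlying row iterator: at most 100 rows are fetched.
    df = job.to_dataframe(max_results=100, create_bqstorage_client=False)
    tbl = job.to_arrow(max_results=100, create_bqstorage_client=False)
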
37 changes: 32 additions & 5 deletions google/cloud/bigquery/_tqdm_helpers.py
@@ -16,13 +16,19 @@

import concurrent.futures
import time
import typing
from typing import Optional
import warnings

try:
import tqdm
except ImportError: # pragma: NO COVER
tqdm = None

if typing.TYPE_CHECKING: # pragma: NO COVER
from google.cloud.bigquery import QueryJob
from google.cloud.bigquery.table import RowIterator

_NO_TQDM_ERROR = (
"A progress bar was requested, but there was an error loading the tqdm "
"library. Please install tqdm to use the progress bar functionality."
@@ -32,7 +38,7 @@


def get_progress_bar(progress_bar_type, description, total, unit):
"""Construct a tqdm progress bar object, if tqdm is ."""
"""Construct a tqdm progress bar object, if tqdm is installed."""
if tqdm is None:
if progress_bar_type is not None:
warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3)
@@ -53,16 +59,34 @@ def get_progress_bar(progress_bar_type, description, total, unit):
return None


def wait_for_query(query_job, progress_bar_type=None):
"""Return query result and display a progress bar while the query running, if tqdm is installed."""
def wait_for_query(
query_job: "QueryJob",
progress_bar_type: Optional[str] = None,
max_results: Optional[int] = None,
) -> "RowIterator":
"""Return query result and display a progress bar while the query running, if tqdm is installed.

Args:
query_job:
The job representing the execution of the query on the server.
progress_bar_type:
The type of progress bar to use to show query progress.
max_results:
The maximum number of rows the row iterator should return.

Returns:
A row iterator over the query results.
"""
default_total = 1
current_stage = None
start_time = time.time()

progress_bar = get_progress_bar(
progress_bar_type, "Query is running", default_total, "query"
)
if progress_bar is None:
return query_job.result()
return query_job.result(max_results=max_results)

i = 0
while True:
if query_job.query_plan:
@@ -75,7 +99,9 @@ def wait_for_query(query_job, progress_bar_type=None):
),
)
try:
query_result = query_job.result(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
query_result = query_job.result(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=max_results
)
progress_bar.update(default_total)
progress_bar.set_description(
"Query complete after {:0.2f}s".format(time.time() - start_time),
Expand All @@ -89,5 +115,6 @@ def wait_for_query(query_job, progress_bar_type=None):
progress_bar.update(i + 1)
i += 1
continue

progress_bar.close()
return query_result
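
For illustration, a hedged sketch of how the QueryJob methods now drive this helper; `wait_for_query` is internal to the library, and the query text below is a placeholder:

    from google.cloud import bigquery
    from google.cloud.bigquery._tqdm_helpers import wait_for_query

    job = bigquery.Client().query("SELECT 1 AS x")  # placeholder query

    # With or without tqdm installed, the helper forwards max_results to
    # QueryJob.result(), so the returned RowIterator yields at most 50 rows.
    rows = wait_for_query(job, progress_bar_type="tqdm", max_results=50)
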
22 changes: 18 additions & 4 deletions google/cloud/bigquery/job/query.py
@@ -1300,12 +1300,14 @@ def result(
return rows

# If changing the signature of this method, make sure to apply the same
# changes to table.RowIterator.to_arrow()
# changes to table.RowIterator.to_arrow(), except for the max_results parameter
# that should only exist here in the QueryJob method.
def to_arrow(
self,
progress_bar_type: str = None,
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
create_bqstorage_client: bool = True,
max_results: Optional[int] = None,
) -> "pyarrow.Table":
"""[Beta] Create a class:`pyarrow.Table` by loading all pages of a
table or query.
@@ -1349,6 +1351,11 @@ def to_arrow(

.. versionadded:: 1.24.0

max_results (Optional[int]):
Maximum number of rows to include in the result. No limit by default.

.. versionadded:: 2.21.0

Returns:
pyarrow.Table
A :class:`pyarrow.Table` populated with row data and column
@@ -1361,22 +1368,24 @@

.. versionadded:: 1.17.0
"""
query_result = wait_for_query(self, progress_bar_type)
query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
return query_result.to_arrow(
progress_bar_type=progress_bar_type,
bqstorage_client=bqstorage_client,
create_bqstorage_client=create_bqstorage_client,
)

# If changing the signature of this method, make sure to apply the same
# changes to table.RowIterator.to_dataframe()
# changes to table.RowIterator.to_dataframe(), except for the max_results parameter
# that should only exist here in the QueryJob method.
def to_dataframe(
self,
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
dtypes: Dict[str, Any] = None,
progress_bar_type: str = None,
create_bqstorage_client: bool = True,
date_as_object: bool = True,
max_results: Optional[int] = None,
) -> "pandas.DataFrame":
"""Return a pandas DataFrame from a QueryJob

@@ -1423,6 +1432,11 @@ def to_dataframe(

.. versionadded:: 1.26.0

max_results (Optional[int]):
Maximum number of rows to include in the result. No limit by default.

.. versionadded:: 2.21.0

Returns:
A :class:`~pandas.DataFrame` populated with row data and column
headers from the query results. The column headers are derived
Expand All @@ -1431,7 +1445,7 @@ def to_dataframe(
Raises:
ValueError: If the `pandas` library cannot be imported.
"""
query_result = wait_for_query(self, progress_bar_type)
query_result = wait_for_query(self, progress_bar_type, max_results=max_results)
return query_result.to_dataframe(
bqstorage_client=bqstorage_client,
dtypes=dtypes,
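
The signature-parity comments above stress that `max_results` deliberately exists only on the `QueryJob` methods, not on the matching `RowIterator` conversions. A hedged sketch of that asymmetry, reusing a placeholder job:

    from google.cloud import bigquery

    job = bigquery.Client().query("SELECT name FROM `my-project.my_dataset.people`")  # placeholder

    tbl = job.to_arrow(max_results=10, create_bqstorage_client=False)  # supported here
    rows = job.result(max_results=10)  # the same cap, applied directly to the iterator
    # rows.to_arrow(max_results=10) would raise TypeError:
    # RowIterator.to_arrow() takes no such parameter.
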
53 changes: 49 additions & 4 deletions google/cloud/bigquery/table.py
@@ -22,7 +22,7 @@
import operator
import pytz
import typing
from typing import Any, Dict, Iterable, Tuple
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple
import warnings

try:
@@ -1415,7 +1417,9 @@ class RowIterator(HTTPIterator):
"""A class for iterating through HTTP/JSON API row list responses.

Args:
client (google.cloud.bigquery.Client): The API client.
client (Optional[google.cloud.bigquery.Client]):
The API client instance. This should always be non-`None`, except for
subclasses that do not use it, namely the ``_EmptyRowIterator``.
api_request (Callable[google.cloud._http.JSONConnection.api_request]):
The function to use to make API requests.
path (str): The method path to query for the list of items.
@@ -1480,7 +1482,7 @@ def __init__(
self._field_to_index = _helpers._field_to_index_mapping(schema)
self._page_size = page_size
self._preserve_order = False
self._project = client.project
self._project = client.project if client is not None else None
self._schema = schema
self._selected_fields = selected_fields
self._table = table
@@ -1895,7 +1897,7 @@ def to_dataframe(
return df


class _EmptyRowIterator(object):
class _EmptyRowIterator(RowIterator):
"""An empty row iterator.

This class prevents API requests when there are no rows to fetch or rows
@@ -1907,6 +1909,18 @@ class _EmptyRowIterator(object):
pages = ()
total_rows = 0

def __init__(
self, client=None, api_request=None, path=None, schema=(), *args, **kwargs
):
super().__init__(
client=client,
api_request=api_request,
path=path,
schema=schema,
*args,
**kwargs,
)

def to_arrow(
self,
progress_bar_type=None,
@@ -1951,6 +1965,37 @@ def to_dataframe(
raise ValueError(_NO_PANDAS_ERROR)
return pandas.DataFrame()

def to_dataframe_iterable(
self,
bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
dtypes: Optional[Dict[str, Any]] = None,
max_queue_size: Optional[int] = None,
) -> Iterator["pandas.DataFrame"]:
"""Create an iterable of pandas DataFrames, to process the table as a stream.

.. versionadded:: 2.21.0

Args:
bqstorage_client:
Ignored. Added for compatibility with RowIterator.

dtypes (Optional[Map[str, Union[str, pandas.Series.dtype]]]):
Ignored. Added for compatibility with RowIterator.

max_queue_size:
Ignored. Added for compatibility with RowIterator.

Returns:
An iterator yielding a single empty :class:`~pandas.DataFrame`.

Raises:
ValueError:
If the :mod:`pandas` library cannot be imported.
"""
if pandas is None:
raise ValueError(_NO_PANDAS_ERROR)
return iter((pandas.DataFrame(),))

def __iter__(self):
return iter(())

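
A hedged sketch of the contract the new `to_dataframe_iterable()` override guarantees; `_EmptyRowIterator` is internal to the library, so constructing it directly is for illustration only:

    from google.cloud.bigquery.table import _EmptyRowIterator

    # An empty result yields exactly one empty DataFrame, so streaming
    # consumers need no special-casing for zero-row queries.
    for frame in _EmptyRowIterator().to_dataframe_iterable():
        assert frame.empty
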
101 changes: 97 additions & 4 deletions tests/unit/job/test_query_pandas.py
@@ -238,6 +238,41 @@ def test_to_arrow():
]


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_to_arrow_max_results_no_progress_bar():
from google.cloud.bigquery import table
from google.cloud.bigquery.job import QueryJob as target_class
from google.cloud.bigquery.schema import SchemaField

connection = _make_connection({})
client = _make_client(connection=connection)
begun_resource = _make_job_resource(job_type="query")
job = target_class.from_api_repr(begun_resource, client)

schema = [
SchemaField("name", "STRING", mode="REQUIRED"),
SchemaField("age", "INTEGER", mode="REQUIRED"),
]
rows = [
{"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
{"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]},
]
path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
row_iterator = table.RowIterator(client, api_request, path, schema)

result_patch = mock.patch(
"google.cloud.bigquery.job.QueryJob.result", return_value=row_iterator,
)
with result_patch as result_patch_tqdm:
tbl = job.to_arrow(create_bqstorage_client=False, max_results=123)

result_patch_tqdm.assert_called_once_with(max_results=123)

assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`")
def test_to_arrow_w_tqdm_w_query_plan():
@@ -290,7 +325,9 @@ def test_to_arrow_w_tqdm_w_query_plan():
assert result_patch_tqdm.call_count == 3
assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
@@ -341,7 +378,9 @@ def test_to_arrow_w_tqdm_w_pending_status():
assert result_patch_tqdm.call_count == 2
assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
Expand Down Expand Up @@ -716,7 +755,9 @@ def test_to_dataframe_w_tqdm_pending():
assert isinstance(df, pandas.DataFrame)
assert len(df) == 4 # verify the number of rows
assert list(df) == ["name", "age"] # verify the column names
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
Expand Down Expand Up @@ -774,4 +815,56 @@ def test_to_dataframe_w_tqdm():
assert isinstance(df, pandas.DataFrame)
assert len(df) == 4 # verify the number of rows
assert list(df) == ["name", "age"]  # verify the column names
result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL)
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=None
)


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`")
def test_to_dataframe_w_tqdm_max_results():
from google.cloud.bigquery import table
from google.cloud.bigquery.job import QueryJob as target_class
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL

begun_resource = _make_job_resource(job_type="query")
schema = [
SchemaField("name", "STRING", mode="NULLABLE"),
SchemaField("age", "INTEGER", mode="NULLABLE"),
]
rows = [{"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}]

connection = _make_connection({})
client = _make_client(connection=connection)
job = target_class.from_api_repr(begun_resource, client)

path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
row_iterator = table.RowIterator(client, api_request, path, schema)

job._properties["statistics"] = {
"query": {
"queryPlan": [
{"name": "S00: Input", "id": "0", "status": "COMPLETE"},
{"name": "S01: Output", "id": "1", "status": "COMPLETE"},
]
},
}
reload_patch = mock.patch(
"google.cloud.bigquery.job._AsyncJob.reload", autospec=True
)
result_patch = mock.patch(
"google.cloud.bigquery.job.QueryJob.result",
side_effect=[concurrent.futures.TimeoutError, row_iterator],
)

with result_patch as result_patch_tqdm, reload_patch:
job.to_dataframe(
progress_bar_type="tqdm", create_bqstorage_client=False, max_results=3
)

assert result_patch_tqdm.call_count == 2
result_patch_tqdm.assert_called_with(
timeout=_PROGRESS_BAR_UPDATE_INTERVAL, max_results=3
)