diff --git a/google/cloud/bigquery/_tqdm_helpers.py b/google/cloud/bigquery/_tqdm_helpers.py new file mode 100644 index 000000000..bdecefe4a --- /dev/null +++ b/google/cloud/bigquery/_tqdm_helpers.py @@ -0,0 +1,94 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helper functions for tqdm progress bar.""" + +import concurrent.futures +import time +import warnings + +try: + import tqdm +except ImportError: # pragma: NO COVER + tqdm = None + +_NO_TQDM_ERROR = ( + "A progress bar was requested, but there was an error loading the tqdm " + "library. Please install tqdm to use the progress bar functionality." ) + +_PROGRESS_BAR_UPDATE_INTERVAL = 0.5 + + +def get_progress_bar(progress_bar_type, description, total, unit): + """Construct a tqdm progress bar object, if tqdm is installed.""" + if tqdm is None: + if progress_bar_type is not None: + warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3) + return None + + try: + if progress_bar_type == "tqdm": + return tqdm.tqdm(desc=description, total=total, unit=unit) + elif progress_bar_type == "tqdm_notebook": + return tqdm.tqdm_notebook(desc=description, total=total, unit=unit) + elif progress_bar_type == "tqdm_gui": + return tqdm.tqdm_gui(desc=description, total=total, unit=unit) + except (KeyError, TypeError): + # Protect ourselves from any tqdm errors. In case of + # unexpected tqdm behavior, just fall back to showing + # no progress bar. 
+ warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3) + return None + + +def wait_for_query(query_job, progress_bar_type=None): + """Return query result and display a progress bar while the query is running, if tqdm is installed.""" + if progress_bar_type is None: + return query_job.result() + + default_total = 1 + current_stage = None + start_time = time.time() + progress_bar = get_progress_bar( + progress_bar_type, "Query is running", default_total, "query" + ) + if progress_bar is None: + # tqdm is unavailable or the progress bar type is unknown; fall back + # to a plain blocking wait instead of raising AttributeError below. + return query_job.result() + i = 0 + while True: + if query_job.query_plan: + default_total = len(query_job.query_plan) + current_stage = query_job.query_plan[i] + progress_bar.total = len(query_job.query_plan) + progress_bar.set_description( + "Query executing stage {} and status {} : {:0.2f}s".format( + current_stage.name, current_stage.status, time.time() - start_time, + ), + ) + try: + query_result = query_job.result(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + progress_bar.update(default_total) + progress_bar.set_description( + "Query complete after {:0.2f}s".format(time.time() - start_time), + ) + break + except concurrent.futures.TimeoutError: + query_job.reload() # Refreshes the state via a GET request. 
+ if current_stage: + if current_stage.status == "COMPLETE": + if i < default_total - 1: + progress_bar.update(i + 1) + i += 1 + continue + progress_bar.close() + return query_result diff --git a/google/cloud/bigquery/job/query.py b/google/cloud/bigquery/job/query.py index 6c9221043..7a1a74954 100644 --- a/google/cloud/bigquery/job/query.py +++ b/google/cloud/bigquery/job/query.py @@ -40,6 +40,7 @@ from google.cloud.bigquery.table import _table_arg_to_table_ref from google.cloud.bigquery.table import TableReference from google.cloud.bigquery.table import TimePartitioning +from google.cloud.bigquery._tqdm_helpers import wait_for_query from google.cloud.bigquery.job.base import _AsyncJob from google.cloud.bigquery.job.base import _DONE_STATE @@ -1259,7 +1260,8 @@ def to_arrow( ..versionadded:: 1.17.0 """ - return self.result().to_arrow( + query_result = wait_for_query(self, progress_bar_type) + return query_result.to_arrow( progress_bar_type=progress_bar_type, bqstorage_client=bqstorage_client, create_bqstorage_client=create_bqstorage_client, @@ -1328,7 +1330,8 @@ def to_dataframe( Raises: ValueError: If the `pandas` library cannot be imported. 
""" - return self.result().to_dataframe( + query_result = wait_for_query(self, progress_bar_type) + return query_result.to_dataframe( bqstorage_client=bqstorage_client, dtypes=dtypes, progress_bar_type=progress_bar_type, diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py index 1ee36c7ea..4bfedd758 100644 --- a/google/cloud/bigquery/table.py +++ b/google/cloud/bigquery/table.py @@ -36,11 +36,6 @@ except ImportError: # pragma: NO COVER pyarrow = None -try: - import tqdm -except ImportError: # pragma: NO COVER - tqdm = None - import google.api_core.exceptions from google.api_core.page_iterator import HTTPIterator @@ -50,6 +45,7 @@ from google.cloud.bigquery.schema import _build_schema_resource from google.cloud.bigquery.schema import _parse_schema_resource from google.cloud.bigquery.schema import _to_schema_fields +from google.cloud.bigquery._tqdm_helpers import get_progress_bar from google.cloud.bigquery.external_config import ExternalConfig from google.cloud.bigquery.encryption_configuration import EncryptionConfiguration @@ -68,10 +64,7 @@ "The pyarrow library is not installed, please install " "pyarrow to use the to_arrow() function." ) -_NO_TQDM_ERROR = ( - "A progress bar was requested, but there was an error loading the tqdm " - "library. Please install tqdm to use the progress bar functionality." 
-) + _TABLE_HAS_NO_SCHEMA = 'Table has no schema: call "client.get_table()"' @@ -1418,32 +1411,6 @@ def total_rows(self): """int: The total number of rows in the table.""" return self._total_rows - def _get_progress_bar(self, progress_bar_type): - """Construct a tqdm progress bar object, if tqdm is installed.""" - if tqdm is None: - if progress_bar_type is not None: - warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3) - return None - - description = "Downloading" - unit = "rows" - - try: - if progress_bar_type == "tqdm": - return tqdm.tqdm(desc=description, total=self.total_rows, unit=unit) - elif progress_bar_type == "tqdm_notebook": - return tqdm.tqdm_notebook( - desc=description, total=self.total_rows, unit=unit - ) - elif progress_bar_type == "tqdm_gui": - return tqdm.tqdm_gui(desc=description, total=self.total_rows, unit=unit) - except (KeyError, TypeError): - # Protect ourselves from any tqdm errors. In case of - # unexpected tqdm behavior, just fall back to showing - # no progress bar. - warnings.warn(_NO_TQDM_ERROR, UserWarning, stacklevel=3) - return None - def _to_page_iterable( self, bqstorage_download, tabledata_list_download, bqstorage_client=None ): @@ -1551,7 +1518,9 @@ def to_arrow( owns_bqstorage_client = bqstorage_client is not None try: - progress_bar = self._get_progress_bar(progress_bar_type) + progress_bar = get_progress_bar( + progress_bar_type, "Downloading", self.total_rows, "rows" + ) record_batches = [] for record_batch in self._to_arrow_iterable( diff --git a/tests/unit/job/test_query_pandas.py b/tests/unit/job/test_query_pandas.py index a481bff69..f9d823eb0 100644 --- a/tests/unit/job/test_query_pandas.py +++ b/tests/unit/job/test_query_pandas.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import concurrent.futures import copy import json @@ -225,6 +226,154 @@ def test_to_arrow(method_kwargs): ] +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") +def test_to_arrow_w_tqdm_w_query_plan(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from google.cloud.bigquery.schema import SchemaField + from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL + + begun_resource = _make_job_resource(job_type="query") + rows = [ + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, + ] + + schema = [ + SchemaField("name", "STRING", mode="REQUIRED"), + SchemaField("age", "INTEGER", mode="REQUIRED"), + ] + connection = _make_connection({}) + client = _make_client(connection=connection) + job = target_class.from_api_repr(begun_resource, client) + + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + job._properties["statistics"] = { + "query": { + "queryPlan": [ + {"name": "S00: Input", "id": "0", "status": "COMPLETE"}, + {"name": "S01: Output", "id": "1", "status": "COMPLETE"}, + ] + }, + } + reload_patch = mock.patch( + "google.cloud.bigquery.job._AsyncJob.reload", autospec=True + ) + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", + side_effect=[ + concurrent.futures.TimeoutError, + concurrent.futures.TimeoutError, + row_iterator, + ], + ) + + with result_patch as result_patch_tqdm, reload_patch: + tbl = job.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=False) + + assert result_patch_tqdm.call_count == 3 + assert isinstance(tbl, pyarrow.Table) + assert tbl.num_rows == 2 + result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + + +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") 
+@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") +def test_to_arrow_w_tqdm_w_pending_status(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from google.cloud.bigquery.schema import SchemaField + from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL + + begun_resource = _make_job_resource(job_type="query") + rows = [ + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, + ] + + schema = [ + SchemaField("name", "STRING", mode="REQUIRED"), + SchemaField("age", "INTEGER", mode="REQUIRED"), + ] + connection = _make_connection({}) + client = _make_client(connection=connection) + job = target_class.from_api_repr(begun_resource, client) + + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + job._properties["statistics"] = { + "query": { + "queryPlan": [ + {"name": "S00: Input", "id": "0", "status": "PENDING"}, + {"name": "S00: Input", "id": "1", "status": "COMPLETE"}, + ] + }, + } + reload_patch = mock.patch( + "google.cloud.bigquery.job._AsyncJob.reload", autospec=True + ) + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", + side_effect=[concurrent.futures.TimeoutError, row_iterator], + ) + + with result_patch as result_patch_tqdm, reload_patch: + tbl = job.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=False) + + assert result_patch_tqdm.call_count == 2 + assert isinstance(tbl, pyarrow.Table) + assert tbl.num_rows == 2 + result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + + +@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`") +@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") +def test_to_arrow_w_tqdm_wo_query_plan(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from 
google.cloud.bigquery.schema import SchemaField + + begun_resource = _make_job_resource(job_type="query") + rows = [ + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, + ] + + schema = [ + SchemaField("name", "STRING", mode="REQUIRED"), + SchemaField("age", "INTEGER", mode="REQUIRED"), + ] + connection = _make_connection({}) + client = _make_client(connection=connection) + job = target_class.from_api_repr(begun_resource, client) + + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + reload_patch = mock.patch( + "google.cloud.bigquery.job._AsyncJob.reload", autospec=True + ) + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", + side_effect=[concurrent.futures.TimeoutError, row_iterator], + ) + + with result_patch as result_patch_tqdm, reload_patch: + tbl = job.to_arrow(progress_bar_type="tqdm", create_bqstorage_client=False) + + assert result_patch_tqdm.call_count == 2 + assert isinstance(tbl, pyarrow.Table) + assert tbl.num_rows == 2 + result_patch_tqdm.assert_called() + + @pytest.mark.skipif(pandas is None, reason="Requires `pandas`") @pytest.mark.parametrize( "method_kwargs", @@ -460,3 +609,115 @@ def test_to_dataframe_with_progress_bar(tqdm_mock): job.to_dataframe(progress_bar_type="tqdm", create_bqstorage_client=False) tqdm_mock.assert_called() + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") +def test_to_dataframe_w_tqdm_pending(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from google.cloud.bigquery.schema import SchemaField + from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL + + begun_resource = _make_job_resource(job_type="query") + schema = [ + SchemaField("name", "STRING", mode="NULLABLE"), + SchemaField("age", 
"INTEGER", mode="NULLABLE"), + ] + rows = [ + {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]}, + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, + {"f": [{"v": "Bhettye Rhubble"}, {"v": "27"}]}, + ] + + connection = _make_connection({}) + client = _make_client(connection=connection) + job = target_class.from_api_repr(begun_resource, client) + + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + job._properties["statistics"] = { + "query": { + "queryPlan": [ + {"name": "S00: Input", "id": "0", "status": "PENDING"}, + {"name": "S01: Output", "id": "1", "status": "COMPLETE"}, + ] + }, + } + reload_patch = mock.patch( + "google.cloud.bigquery.job._AsyncJob.reload", autospec=True + ) + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", + side_effect=[concurrent.futures.TimeoutError, row_iterator], + ) + + with result_patch as result_patch_tqdm, reload_patch: + df = job.to_dataframe(progress_bar_type="tqdm", create_bqstorage_client=False) + + assert result_patch_tqdm.call_count == 2 + assert isinstance(df, pandas.DataFrame) + assert len(df) == 4 # verify the number of rows + assert list(df) == ["name", "age"] # verify the column names + result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) + + +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +@pytest.mark.skipif(tqdm is None, reason="Requires `tqdm`") +def test_to_dataframe_w_tqdm(): + from google.cloud.bigquery import table + from google.cloud.bigquery.job import QueryJob as target_class + from google.cloud.bigquery.schema import SchemaField + from google.cloud.bigquery._tqdm_helpers import _PROGRESS_BAR_UPDATE_INTERVAL + + begun_resource = _make_job_resource(job_type="query") + schema = [ + SchemaField("name", "STRING", mode="NULLABLE"), + SchemaField("age", "INTEGER", mode="NULLABLE"), + ] + rows = [ + {"f": [{"v": 
"Phred Phlyntstone"}, {"v": "32"}]}, + {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]}, + {"f": [{"v": "Wylma Phlyntstone"}, {"v": "29"}]}, + {"f": [{"v": "Bhettye Rhubble"}, {"v": "27"}]}, + ] + + connection = _make_connection({}) + client = _make_client(connection=connection) + job = target_class.from_api_repr(begun_resource, client) + + path = "/foo" + api_request = mock.Mock(return_value={"rows": rows}) + row_iterator = table.RowIterator(client, api_request, path, schema) + + job._properties["statistics"] = { + "query": { + "queryPlan": [ + {"name": "S00: Input", "id": "0", "status": "COMPLETE"}, + {"name": "S01: Output", "id": "1", "status": "COMPLETE"}, + ] + }, + } + reload_patch = mock.patch( + "google.cloud.bigquery.job._AsyncJob.reload", autospec=True + ) + result_patch = mock.patch( + "google.cloud.bigquery.job.QueryJob.result", + side_effect=[ + concurrent.futures.TimeoutError, + concurrent.futures.TimeoutError, + row_iterator, + ], + ) + + with result_patch as result_patch_tqdm, reload_patch: + df = job.to_dataframe(progress_bar_type="tqdm", create_bqstorage_client=False) + + assert result_patch_tqdm.call_count == 3 + assert isinstance(df, pandas.DataFrame) + assert len(df) == 4 # verify the number of rows + assert list(df) == ["name", "age"] # verify the column names + result_patch_tqdm.assert_called_with(timeout=_PROGRESS_BAR_UPDATE_INTERVAL) diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py index eccc46a7a..be67eafcd 100644 --- a/tests/unit/test_table.py +++ b/tests/unit/test_table.py @@ -2433,7 +2433,7 @@ def test_to_dataframe_progress_bar( self.assertEqual(len(df), 4) @unittest.skipIf(pandas is None, "Requires `pandas`") - @mock.patch("google.cloud.bigquery.table.tqdm", new=None) + @mock.patch("google.cloud.bigquery._tqdm_helpers.tqdm", new=None) def test_to_dataframe_no_tqdm_no_progress_bar(self): from google.cloud.bigquery.schema import SchemaField @@ -2461,7 +2461,7 @@ def test_to_dataframe_no_tqdm_no_progress_bar(self): 
self.assertEqual(len(df), 4) @unittest.skipIf(pandas is None, "Requires `pandas`") - @mock.patch("google.cloud.bigquery.table.tqdm", new=None) + @mock.patch("google.cloud.bigquery._tqdm_helpers.tqdm", new=None) def test_to_dataframe_no_tqdm(self): from google.cloud.bigquery.schema import SchemaField