diff --git a/google/cloud/bigquery/_tqdm_helpers.py b/google/cloud/bigquery/_tqdm_helpers.py index bdecefe4a..2fcf2a981 100644 --- a/google/cloud/bigquery/_tqdm_helpers.py +++ b/google/cloud/bigquery/_tqdm_helpers.py @@ -55,15 +55,14 @@ def get_progress_bar(progress_bar_type, description, total, unit): def wait_for_query(query_job, progress_bar_type=None): """Return query result and display a progress bar while the query running, if tqdm is installed.""" - if progress_bar_type is None: - return query_job.result() - default_total = 1 current_stage = None start_time = time.time() progress_bar = get_progress_bar( progress_bar_type, "Query is running", default_total, "query" ) + if progress_bar is None: + return query_job.result() i = 0 while True: if query_job.query_plan: diff --git a/google/cloud/bigquery/magics/magics.py b/google/cloud/bigquery/magics/magics.py index 5645a84a5..f04a6364a 100644 --- a/google/cloud/bigquery/magics/magics.py +++ b/google/cloud/bigquery/magics/magics.py @@ -182,6 +182,7 @@ def __init__(self): self._default_query_job_config = bigquery.QueryJobConfig() self._bigquery_client_options = client_options.ClientOptions() self._bqstorage_client_options = client_options.ClientOptions() + self._progress_bar_type = "tqdm" @property def credentials(self): @@ -313,6 +314,26 @@ def default_query_job_config(self): def default_query_job_config(self, value): self._default_query_job_config = value + @property + def progress_bar_type(self): + """str: Default progress bar type to use to display progress bar while + executing queries through IPython magics. + + Note:: + Install the ``tqdm`` package to use this feature. + + Example: + Manually setting the progress_bar_type: + + >>> from google.cloud.bigquery import magics + >>> magics.context.progress_bar_type = "tqdm" + """ + return self._progress_bar_type + + @progress_bar_type.setter + def progress_bar_type(self, value): + self._progress_bar_type = value + context = Context() @@ -524,6 +545,15 @@ def _create_dataset_if_necessary(client, dataset_id): "name (ex. $my_dict_var)." ), ) +@magic_arguments.argument( + "--progress_bar_type", + type=str, + default=None, + help=( + "Sets progress bar type to display a progress bar while executing the query." + "Defaults to use tqdm. Install the ``tqdm`` package to use this feature." + ), +) def _cell_magic(line, query): """Underlying function for bigquery cell magic @@ -687,12 +717,16 @@ def _cell_magic(line, query): ) return query_job + progress_bar = context.progress_bar_type or args.progress_bar_type + if max_results: result = query_job.result(max_results=max_results).to_dataframe( - bqstorage_client=bqstorage_client + bqstorage_client=bqstorage_client, progress_bar_type=progress_bar ) else: - result = query_job.to_dataframe(bqstorage_client=bqstorage_client) + result = query_job.to_dataframe( + bqstorage_client=bqstorage_client, progress_bar_type=progress_bar + ) if args.destination_var: IPython.get_ipython().push({args.destination_var: result}) diff --git a/tests/unit/test_magics.py b/tests/unit/test_magics.py index a7cf92919..ff41fe720 100644 --- a/tests/unit/test_magics.py +++ b/tests/unit/test_magics.py @@ -623,7 +623,7 @@ def warning_match(warning): assert client_info.user_agent == "ipython-" + IPython.__version__ query_job_mock.to_dataframe.assert_called_once_with( - bqstorage_client=bqstorage_instance_mock + bqstorage_client=bqstorage_instance_mock, progress_bar_type="tqdm" ) assert isinstance(return_value, pandas.DataFrame) @@ -665,7 +665,9 @@ def test_bigquery_magic_with_rest_client_requested(monkeypatch): return_value = ip.run_cell_magic("bigquery", "--use_rest_api", sql) bqstorage_mock.assert_not_called() - query_job_mock.to_dataframe.assert_called_once_with(bqstorage_client=None) + query_job_mock.to_dataframe.assert_called_once_with( + bqstorage_client=None, progress_bar_type="tqdm" + ) assert isinstance(return_value, pandas.DataFrame) @@ -1167,6 +1169,71 @@ def test_bigquery_magic_w_maximum_bytes_billed_w_context_setter(): assert sent_config["maximumBytesBilled"] == "10203" +@pytest.mark.usefixtures("ipython_interactive") +@pytest.mark.skipif(pandas is None, reason="Requires `pandas`") +def test_bigquery_magic_w_progress_bar_type_w_context_setter(monkeypatch): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context._project = None + + magics.context.progress_bar_type = "tqdm_gui" + + mock_credentials = mock.create_autospec( + google.auth.credentials.Credentials, instance=True + ) + + # Set up the context with monkeypatch so that it's reset for subsequent + # tests. + monkeypatch.setattr(magics.context, "_credentials", mock_credentials) + + # Mock out the BigQuery Storage API. + bqstorage_mock = mock.create_autospec(bigquery_storage.BigQueryReadClient) + bqstorage_client_patch = mock.patch( + "google.cloud.bigquery_storage.BigQueryReadClient", bqstorage_mock + ) + + sql = "SELECT 17 AS num" + result = pandas.DataFrame([17], columns=["num"]) + run_query_patch = mock.patch( + "google.cloud.bigquery.magics.magics._run_query", autospec=True + ) + query_job_mock = mock.create_autospec( + google.cloud.bigquery.job.QueryJob, instance=True + ) + query_job_mock.to_dataframe.return_value = result + with run_query_patch as run_query_mock, bqstorage_client_patch: + run_query_mock.return_value = query_job_mock + + return_value = ip.run_cell_magic("bigquery", "--use_rest_api", sql) + + bqstorage_mock.assert_not_called() + query_job_mock.to_dataframe.assert_called_once_with( + bqstorage_client=None, progress_bar_type=magics.context.progress_bar_type + ) + + assert isinstance(return_value, pandas.DataFrame) + + +@pytest.mark.usefixtures("ipython_interactive") +def test_bigquery_magic_with_progress_bar_type(): + ip = IPython.get_ipython() + ip.extension_manager.load_extension("google.cloud.bigquery") + magics.context.progress_bar_type = None + + run_query_patch = mock.patch( + "google.cloud.bigquery.magics.magics._run_query", autospec=True + ) + with run_query_patch as run_query_mock: + ip.run_cell_magic( + "bigquery", "--progress_bar_type=tqdm_gui", "SELECT 17 as num" + ) + + progress_bar_used = run_query_mock.mock_calls[1][2]["progress_bar_type"] + assert progress_bar_used == "tqdm_gui" + # context progress bar type should not change + assert magics.context.progress_bar_type is None + + @pytest.mark.usefixtures("ipython_interactive") def test_bigquery_magic_with_project(): ip = IPython.get_ipython()