Skip to content

Commit

Permalink
fix: no longer raise a warning in to_dataframe if max_results set (#815)
Browse files Browse the repository at this point in the history

That warning should only be used when BQ Storage client is
explicitly passed in to RowIterator methods when max_results
value is also set.
  • Loading branch information
plamut committed Jul 27, 2021
1 parent 3b70891 commit 3c1be14
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 11 deletions.
30 changes: 25 additions & 5 deletions google/cloud/bigquery/table.py
Expand Up @@ -1552,11 +1552,6 @@ def _validate_bqstorage(self, bqstorage_client, create_bqstorage_client):
return False

if self.max_results is not None:
warnings.warn(
"Cannot use bqstorage_client if max_results is set, "
"reverting to fetching data with the REST endpoint.",
stacklevel=2,
)
return False

try:
Expand Down Expand Up @@ -1604,6 +1599,25 @@ def total_rows(self):
"""int: The total number of rows in the table."""
return self._total_rows

def _maybe_warn_max_results(
self, bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"],
):
"""Issue a warning if BQ Storage client is not ``None`` with ``max_results`` set.
This helper method should be used directly in the relevant top-level public
methods, so that the warning is issued for the correct line in user code.
Args:
bqstorage_client:
The BigQuery Storage client intended to use for downloading result rows.
"""
if bqstorage_client is not None and self.max_results is not None:
warnings.warn(
"Cannot use bqstorage_client if max_results is set, "
"reverting to fetching data with the REST endpoint.",
stacklevel=3,
)

def _to_page_iterable(
self, bqstorage_download, tabledata_list_download, bqstorage_client=None
):
Expand Down Expand Up @@ -1700,6 +1714,8 @@ def to_arrow(
if pyarrow is None:
raise ValueError(_NO_PYARROW_ERROR)

self._maybe_warn_max_results(bqstorage_client)

if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client):
create_bqstorage_client = False
bqstorage_client = None
Expand Down Expand Up @@ -1790,6 +1806,8 @@ def to_dataframe_iterable(
if dtypes is None:
dtypes = {}

self._maybe_warn_max_results(bqstorage_client)

column_names = [field.name for field in self._schema]
bqstorage_download = functools.partial(
_pandas_helpers.download_dataframe_bqstorage,
Expand Down Expand Up @@ -1896,6 +1914,8 @@ def to_dataframe(
if dtypes is None:
dtypes = {}

self._maybe_warn_max_results(bqstorage_client)

if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client):
create_bqstorage_client = False
bqstorage_client = None
Expand Down
160 changes: 154 additions & 6 deletions tests/unit/test_table.py
Expand Up @@ -15,6 +15,7 @@
import datetime
import logging
import time
import types
import unittest
import warnings

Expand Down Expand Up @@ -1862,6 +1863,15 @@ def test__validate_bqstorage_returns_false_when_completely_cached(self):
)
)

def test__validate_bqstorage_returns_false_if_max_results_set(self):
    """With ``max_results`` set, the BQ Storage path must be rejected."""
    # first_page_response=None so the iterator does not count as fully cached.
    row_iterator = self._make_one(max_results=10, first_page_response=None)
    self.assertFalse(
        row_iterator._validate_bqstorage(
            bqstorage_client=None, create_bqstorage_client=True
        )
    )

def test__validate_bqstorage_returns_false_if_missing_dependency(self):
iterator = self._make_one(first_page_response=None) # not cached

Expand Down Expand Up @@ -2105,7 +2115,7 @@ def test_to_arrow_w_empty_table(self):
@unittest.skipIf(
bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
)
def test_to_arrow_max_results_w_create_bqstorage_warning(self):
def test_to_arrow_max_results_w_explicit_bqstorage_client_warning(self):
from google.cloud.bigquery.schema import SchemaField

schema = [
Expand All @@ -2119,6 +2129,7 @@ def test_to_arrow_max_results_w_create_bqstorage_warning(self):
path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
mock_client = _mock_client()
mock_bqstorage_client = mock.sentinel.bq_storage_client

row_iterator = self._make_one(
client=mock_client,
Expand All @@ -2129,7 +2140,7 @@ def test_to_arrow_max_results_w_create_bqstorage_warning(self):
)

with warnings.catch_warnings(record=True) as warned:
row_iterator.to_arrow(create_bqstorage_client=True)
row_iterator.to_arrow(bqstorage_client=mock_bqstorage_client)

matches = [
warning
Expand All @@ -2139,6 +2150,49 @@ def test_to_arrow_max_results_w_create_bqstorage_warning(self):
and "REST" in str(warning)
]
self.assertEqual(len(matches), 1, msg="User warning was not emitted.")
self.assertIn(
__file__, str(matches[0]), msg="Warning emitted with incorrect stacklevel"
)
mock_client._ensure_bqstorage_client.assert_not_called()

@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
@unittest.skipIf(
    bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
)
def test_to_arrow_max_results_w_create_bqstorage_client_no_warning(self):
    """``create_bqstorage_client=True`` alone must not trigger the warning."""
    from google.cloud.bigquery.schema import SchemaField

    fields = [
        SchemaField("name", "STRING", mode="REQUIRED"),
        SchemaField("age", "INTEGER", mode="REQUIRED"),
    ]
    fake_rows = [
        {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
        {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
    ]
    mock_client = _mock_client()
    row_iterator = self._make_one(
        client=mock_client,
        api_request=mock.Mock(return_value={"rows": fake_rows}),
        path="/foo",
        schema=fields,
        max_results=42,
    )

    with warnings.catch_warnings(record=True) as warned:
        row_iterator.to_arrow(create_bqstorage_client=True)

    # No "cannot use bqstorage_client" warning is expected: the user never
    # passed an explicit BQ Storage client, only allowed one to be created.
    unexpected = [
        warning
        for warning in warned
        if warning.category is UserWarning
        and "cannot use bqstorage_client" in str(warning).lower()
        and "REST" in str(warning)
    ]
    self.assertFalse(unexpected)
    mock_client._ensure_bqstorage_client.assert_not_called()

@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
Expand Down Expand Up @@ -2372,7 +2426,6 @@ def test_to_arrow_w_pyarrow_none(self):
@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_to_dataframe_iterable(self):
from google.cloud.bigquery.schema import SchemaField
import types

schema = [
SchemaField("name", "STRING", mode="REQUIRED"),
Expand Down Expand Up @@ -2415,7 +2468,6 @@ def test_to_dataframe_iterable(self):
@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_to_dataframe_iterable_with_dtypes(self):
from google.cloud.bigquery.schema import SchemaField
import types

schema = [
SchemaField("name", "STRING", mode="REQUIRED"),
Expand Down Expand Up @@ -2527,6 +2579,61 @@ def test_to_dataframe_iterable_w_bqstorage(self):
# Don't close the client if it was passed in.
bqstorage_client._transport.grpc_channel.close.assert_not_called()

@unittest.skipIf(pandas is None, "Requires `pandas`")
@unittest.skipIf(
    bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
)
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_to_dataframe_iterable_w_bqstorage_max_results_warning(self):
    """An explicit BQ Storage client plus ``max_results`` warns, then falls back to REST."""
    from google.cloud.bigquery import schema
    from google.cloud.bigquery import table as mut

    bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient)

    iterator_schema = [
        schema.SchemaField("name", "STRING", mode="REQUIRED"),
        schema.SchemaField("age", "INTEGER", mode="REQUIRED"),
    ]
    # Two REST pages; the second has no "pageToken", ending iteration.
    pages = [
        {
            "rows": [{"f": [{"v": "Bengt"}, {"v": "32"}]}],
            "pageToken": "NEXTPAGE",
        },
        {"rows": [{"f": [{"v": "Sven"}, {"v": "33"}]}]},
    ]
    api_request = mock.Mock(side_effect=pages)
    row_iterator = mut.RowIterator(
        _mock_client(),
        api_request,
        "/foo",
        iterator_schema,
        table=mut.TableReference.from_string("proj.dset.tbl"),
        selected_fields=iterator_schema,
        max_results=25,
    )

    with warnings.catch_warnings(record=True) as warned:
        dfs = row_iterator.to_dataframe_iterable(bqstorage_client=bqstorage_client)

    # Exactly one warning must be emitted, attributed to this file
    # (i.e. issued with the correct stacklevel).
    matches = [
        warning
        for warning in warned
        if warning.category is UserWarning
        and "cannot use bqstorage_client" in str(warning).lower()
        and "REST" in str(warning)
    ]
    assert len(matches) == 1, "User warning was not emitted."
    assert __file__ in str(matches[0]), "Warning emitted with incorrect stacklevel"

    # Data still arrives via the REST endpoint: one DataFrame per page.
    dataframes = list(dfs)
    assert len(dataframes) == 2
    assert isinstance(dataframes[0], pandas.DataFrame)
    assert isinstance(dataframes[1], pandas.DataFrame)

@mock.patch("google.cloud.bigquery.table.pandas", new=None)
def test_to_dataframe_iterable_error_if_pandas_is_none(self):
from google.cloud.bigquery.schema import SchemaField
Expand Down Expand Up @@ -2926,7 +3033,7 @@ def test_to_dataframe_max_results_w_bqstorage_warning(self):
self.assertEqual(len(matches), 1, msg="User warning was not emitted.")

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_to_dataframe_max_results_w_create_bqstorage_warning(self):
def test_to_dataframe_max_results_w_explicit_bqstorage_client_warning(self):
from google.cloud.bigquery.schema import SchemaField

schema = [
Expand All @@ -2940,6 +3047,7 @@ def test_to_dataframe_max_results_w_create_bqstorage_warning(self):
path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
mock_client = _mock_client()
mock_bqstorage_client = mock.sentinel.bq_storage_client

row_iterator = self._make_one(
client=mock_client,
Expand All @@ -2950,7 +3058,7 @@ def test_to_dataframe_max_results_w_create_bqstorage_warning(self):
)

with warnings.catch_warnings(record=True) as warned:
row_iterator.to_dataframe(create_bqstorage_client=True)
row_iterator.to_dataframe(bqstorage_client=mock_bqstorage_client)

matches = [
warning
Expand All @@ -2960,6 +3068,46 @@ def test_to_dataframe_max_results_w_create_bqstorage_warning(self):
and "REST" in str(warning)
]
self.assertEqual(len(matches), 1, msg="User warning was not emitted.")
self.assertIn(
__file__, str(matches[0]), msg="Warning emitted with incorrect stacklevel"
)
mock_client._ensure_bqstorage_client.assert_not_called()

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_to_dataframe_max_results_w_create_bqstorage_client_no_warning(self):
    """``create_bqstorage_client=True`` alone must not trigger the warning."""
    from google.cloud.bigquery.schema import SchemaField

    fields = [
        SchemaField("name", "STRING", mode="REQUIRED"),
        SchemaField("age", "INTEGER", mode="REQUIRED"),
    ]
    fake_rows = [
        {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
        {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
    ]
    mock_client = _mock_client()
    row_iterator = self._make_one(
        client=mock_client,
        api_request=mock.Mock(return_value={"rows": fake_rows}),
        path="/foo",
        schema=fields,
        max_results=42,
    )

    with warnings.catch_warnings(record=True) as warned:
        row_iterator.to_dataframe(create_bqstorage_client=True)

    # Merely allowing a BQ Storage client to be created (as opposed to
    # explicitly passing one in) must not produce the max_results warning.
    unexpected = [
        warning
        for warning in warned
        if warning.category is UserWarning
        and "cannot use bqstorage_client" in str(warning).lower()
        and "REST" in str(warning)
    ]
    self.assertFalse(unexpected)
    mock_client._ensure_bqstorage_client.assert_not_called()

@unittest.skipIf(pandas is None, "Requires `pandas`")
Expand Down

0 comments on commit 3c1be14

Please sign in to comment.