perf: avoid extra API calls from to_dataframe if all rows are cached #384

Merged · 1 commit · Nov 11, 2020
56 changes: 40 additions & 16 deletions google/cloud/bigquery/table.py
@@ -1351,6 +1351,41 @@ def __init__(
self._total_rows = total_rows
self._first_page_response = first_page_response

def _is_completely_cached(self):
"""Check if all results are completely cached.

This is useful to know because, when all rows are cached, we can
skip alternative download mechanisms.
"""
if self._first_page_response is None or self.next_page_token:
return False

return self._first_page_response.get(self._next_token) is None
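
In effect, a result set counts as fully cached only when the first page response was retained and carries no next-page token; the method also bails out once iteration has advanced past the first page (self.next_page_token is set). A minimal standalone sketch of the check, with the helper name, sample payloads, and token key all illustrative (the real method reads the key from self._next_token):

def is_completely_cached(first_page, next_token_key):
    # No retained first page: nothing is cached.
    if first_page is None:
        return False
    # A next-page token means more rows remain on the server.
    return first_page.get(next_token_key) is None

# A small result that fit in one page is fully cached:
assert is_completely_cached({"rows": [{}, {}], "totalRows": "2"}, "pageToken")
# A paginated result is not:
assert not is_completely_cached({"pageToken": "next-page"}, "pageToken")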

def _validate_bqstorage(self, bqstorage_client, create_bqstorage_client):
"""Returns if the BigQuery Storage API can be used.

Returns:
bool
True if the BigQuery Storage client can be used or created.
"""
using_bqstorage_api = bqstorage_client or create_bqstorage_client
if not using_bqstorage_api:
return False

if self._is_completely_cached():
return False

if self.max_results is not None:
warnings.warn(
"Cannot use bqstorage_client if max_results is set, "
"reverting to fetching data with the REST endpoint.",
stacklevel=2,
)
return False

return True
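
The ordering of these checks matters: the cheap opt-in test runs first, the cache test falls back silently, and only the max_results conflict warns, since there the caller explicitly requested behavior the Storage API cannot honor. A hedged sketch of that decision order, with iterator state passed as plain arguments (names are illustrative, not part of this PR):

import warnings

def can_use_bqstorage(requested, completely_cached, max_results):
    # Neither an existing client nor opt-in creation: REST only.
    if not requested:
        return False
    # All rows already in memory: silently skip the Storage API.
    if completely_cached:
        return False
    # max_results cannot be honored by the Storage API: warn, then fall back.
    if max_results is not None:
        warnings.warn(
            "Cannot use bqstorage_client if max_results is set, "
            "reverting to fetching data with the REST endpoint."
        )
        return False
    return True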

def _get_next_page_response(self):
"""Requests the next page from the path provided.

@@ -1412,6 +1447,9 @@ def _get_progress_bar(self, progress_bar_type):
def _to_page_iterable(
self, bqstorage_download, tabledata_list_download, bqstorage_client=None
):
if not self._validate_bqstorage(bqstorage_client, False):
bqstorage_client = None

if bqstorage_client is not None:
for item in bqstorage_download():
yield item
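
Note that False is passed for create_bqstorage_client here: by the time pages are iterated, a client either already exists or will not be created. A rough standalone sketch of the fallback pattern this gate enables (illustrative, not the PR's exact code):

def to_page_iterable(bqstorage_download, tabledata_list_download,
                     bqstorage_client, validate):
    if not validate(bqstorage_client):
        bqstorage_client = None  # drop to the REST download path
    if bqstorage_client is not None:
        yield from bqstorage_download()
    else:
        yield from tabledata_list_download()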
@@ -1503,14 +1541,7 @@ def to_arrow(
if pyarrow is None:
raise ValueError(_NO_PYARROW_ERROR)

if (
bqstorage_client or create_bqstorage_client
) and self.max_results is not None:
warnings.warn(
"Cannot use bqstorage_client if max_results is set, "
"reverting to fetching data with the REST endpoint.",
stacklevel=2,
)
if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client):
create_bqstorage_client = False
bqstorage_client = None

@@ -1687,14 +1718,7 @@ def to_dataframe(
if dtypes is None:
dtypes = {}

if (
bqstorage_client or create_bqstorage_client
) and self.max_results is not None:
warnings.warn(
"Cannot use bqstorage_client if max_results is set, "
"reverting to fetching data with the REST endpoint.",
stacklevel=2,
)
if not self._validate_bqstorage(bqstorage_client, create_bqstorage_client):
create_bqstorage_client = False
bqstorage_client = None
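
With to_arrow and to_dataframe both routed through _validate_bqstorage, the caller-visible effect is that a result set fitting in the cached first page is converted directly from memory, with no BigQuery Storage client created and no extra list-rows API calls. A hedged usage sketch (the client setup and query are illustrative):

from google.cloud import bigquery

client = bigquery.Client()

# A tiny result set: all rows arrive in the cached first page.
rows = client.query("SELECT 17 AS answer").result()

# Served from the cached page; no Storage client is created even though
# creation is allowed here.
df = rows.to_dataframe(create_bqstorage_client=True)
print(df)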

28 changes: 24 additions & 4 deletions tests/unit/job/test_query_pandas.py
@@ -99,6 +99,7 @@ def test_to_dataframe_bqstorage_preserve_order(query):
]
},
"totalRows": "4",
"pageToken": "next-page",
}
connection = _make_connection(get_query_results_resource, job_resource)
client = _make_client(connection=connection)
@@ -133,7 +134,16 @@ def test_to_dataframe_bqstorage_preserve_order(query):


@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
def test_to_arrow():
@pytest.mark.parametrize(
"method_kwargs",
[
{"create_bqstorage_client": False},
# Since all rows are contained in the first page of results, the BigQuery
# Storage API won't actually be used.
{"create_bqstorage_client": True},
],
)
def test_to_arrow(method_kwargs):
from google.cloud.bigquery.job import QueryJob as target_class

begun_resource = _make_job_resource(job_type="query")
@@ -182,7 +192,7 @@ def test_to_arrow():
client = _make_client(connection=connection)
job = target_class.from_api_repr(begun_resource, client)

tbl = job.to_arrow(create_bqstorage_client=False)
tbl = job.to_arrow(**method_kwargs)

assert isinstance(tbl, pyarrow.Table)
assert tbl.num_rows == 2
@@ -216,7 +226,16 @@ def test_to_arrow():


@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
def test_to_dataframe():
@pytest.mark.parametrize(
"method_kwargs",
[
{"create_bqstorage_client": False},
# Since all rows are contained in the first page of results, the BigQuery
# Storage API won't actually be used.
{"create_bqstorage_client": True},
],
)
def test_to_dataframe(method_kwargs):
from google.cloud.bigquery.job import QueryJob as target_class

begun_resource = _make_job_resource(job_type="query")
@@ -243,7 +262,7 @@ def test_to_dataframe():
client = _make_client(connection=connection)
job = target_class.from_api_repr(begun_resource, client)

df = job.to_dataframe(create_bqstorage_client=False)
df = job.to_dataframe(**method_kwargs)

assert isinstance(df, pandas.DataFrame)
assert len(df) == 4 # verify the number of rows
@@ -288,6 +307,7 @@ def test_to_dataframe_bqstorage():
{"name": "age", "type": "INTEGER", "mode": "NULLABLE"},
]
},
"pageToken": "next-page",
}
connection = _make_connection(query_resource)
client = _make_client(connection=connection)