perf: cache first page of jobs.getQueryResults rows (#374)
Co-authored-by: Steffany Brown <30247553+steffnay@users.noreply.github.com>
tswast and steffnay committed Nov 10, 2020
1 parent cd9febd commit 86f6a51
Showing 6 changed files with 115 additions and 60 deletions.
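
In short: before this change, Client._get_query_results always requested maxResults=0, so the initial jobs.getQueryResults call could confirm completion but never returned rows, and the row iterator had to issue a second API request for the first page. With this change, the rows in that initial response are cached and handed to the iterator, saving one round trip in the common case. A rough usage sketch of the affected path (client setup and query text are assumptions, not part of this commit):

    from google.cloud import bigquery

    client = bigquery.Client()  # assumes default project and credentials
    job = client.query("SELECT col1 FROM `my_dataset.my_table`")  # hypothetical query

    # result() polls jobs.getQueryResults; that response now carries the first
    # page of rows, which the returned RowIterator serves without another request.
    for row in job.result():
        print(row.col1)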
4 changes: 3 additions & 1 deletion google/cloud/bigquery/client.py
@@ -1534,7 +1534,7 @@ def _get_query_results(
A new ``_QueryResults`` instance.
"""

extra_params = {"maxResults": 0}
extra_params = {}

if project is None:
project = self.project
@@ -3187,6 +3187,7 @@ def _list_rows_from_query_results(
page_size=None,
retry=DEFAULT_RETRY,
timeout=None,
first_page_response=None,
):
"""List the rows of a completed query.
See
@@ -3247,6 +3248,7 @@ def _list_rows_from_query_results(
table=destination,
extra_params=params,
total_rows=total_rows,
first_page_response=first_page_response,
)
return row_iterator

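Dropping the {"maxResults": 0} default means the initial jobs.getQueryResults response can now carry row data. Based on the test fixtures later in this diff, a completed query's first response looks roughly like this (values illustrative):

    first_page_response = {
        "jobComplete": True,
        "jobReference": {"projectId": "my-project", "jobId": "my-job"},
        "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
        "totalRows": "3",
        "rows": [{"f": [{"v": "abc"}]}],  # first page, cached by the iterator
        "pageToken": "next-page",         # used to fetch the remaining pages
    }

This whole parsed response is what result() forwards to _list_rows_from_query_results as first_page_response.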
85 changes: 54 additions & 31 deletions google/cloud/bigquery/job/query.py
@@ -990,48 +990,22 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True):
Returns:
bool: True if the job is complete, False otherwise.
"""
is_done = (
# Only consider a QueryJob complete when we know we have the final
# query results available.
self._query_results is not None
and self._query_results.complete
and self.state == _DONE_STATE
)
# Do not refresh if the state is already done, as the job will not
# change once complete.
is_done = self.state == _DONE_STATE
if not reload or is_done:
return is_done

# Since the API to getQueryResults can hang up to the timeout value
# (default of 10 seconds), set the timeout parameter to ensure that
# the timeout from the futures API is respected. See:
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
timeout_ms = None
if self._done_timeout is not None:
# Subtract a buffer for context switching, network latency, etc.
api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
api_timeout = max(min(api_timeout, 10), 0)
self._done_timeout -= api_timeout
self._done_timeout = max(0, self._done_timeout)
timeout_ms = int(api_timeout * 1000)
self._reload_query_results(retry=retry, timeout=timeout)

# If an explicit timeout is not given, fall back to the transport timeout
# stored in _blocking_poll() in the process of polling for job completion.
transport_timeout = timeout if timeout is not None else self._transport_timeout

self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
timeout_ms=timeout_ms,
location=self.location,
timeout=transport_timeout,
)

# Only reload the job once we know the query is complete.
# This will ensure that fields such as the destination table are
# correctly populated.
if self._query_results.complete and self.state != _DONE_STATE:
if self._query_results.complete:
self.reload(retry=retry, timeout=transport_timeout)

return self.state == _DONE_STATE
@@ -1098,6 +1072,45 @@ def _begin(self, client=None, retry=DEFAULT_RETRY, timeout=None):
exc.query_job = self
raise

def _reload_query_results(self, retry=DEFAULT_RETRY, timeout=None):
"""Refresh the cached query results.
Args:
retry (Optional[google.api_core.retry.Retry]):
How to retry the call that retrieves query results.
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
"""
if self._query_results and self._query_results.complete:
return

# Since the API to getQueryResults can hang up to the timeout value
# (default of 10 seconds), set the timeout parameter to ensure that
# the timeout from the futures API is respected. See:
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
timeout_ms = None
if self._done_timeout is not None:
# Subtract a buffer for context switching, network latency, etc.
api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
api_timeout = max(min(api_timeout, 10), 0)
self._done_timeout -= api_timeout
self._done_timeout = max(0, self._done_timeout)
timeout_ms = int(api_timeout * 1000)

# If an explicit timeout is not given, fall back to the transport timeout
# stored in _blocking_poll() in the process of polling for job completion.
transport_timeout = timeout if timeout is not None else self._transport_timeout

self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
timeout_ms=timeout_ms,
location=self.location,
timeout=transport_timeout,
)

def result(
self,
page_size=None,
@@ -1144,6 +1157,11 @@ def result(
"""
try:
super(QueryJob, self).result(retry=retry, timeout=timeout)

# Since the job could already be "done" (e.g. got a finished job
# via client.get_job), the superclass call to done() might not
# set the self._query_results cache.
self._reload_query_results(retry=retry, timeout=timeout)
except exceptions.GoogleAPICallError as exc:
exc.message += self._format_for_exception(self.query, self.job_id)
exc.query_job = self
@@ -1158,10 +1176,14 @@ def result(
if self._query_results.total_rows is None:
return _EmptyRowIterator()

first_page_response = None
if max_results is None and page_size is None and start_index is None:
first_page_response = self._query_results._properties

rows = self._client._list_rows_from_query_results(
self._query_results.job_id,
self.job_id,
self.location,
self._query_results.project,
self.project,
self._query_results.schema,
total_rows=self._query_results.total_rows,
destination=self.destination,
@@ -1170,6 +1192,7 @@ def result(
start_index=start_index,
retry=retry,
timeout=timeout,
first_page_response=first_page_response,
)
rows._preserve_order = _contains_order_by(self.query)
return rows
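Two details are worth noting in result() above. First, _reload_query_results is called explicitly because done() can return True without ever fetching query results, for example when a finished job is obtained via client.get_job. Second, the cached page is only forwarded when max_results, page_size, and start_index are all unset, since it was fetched with default paging; the tests below verify the cached rows are discarded otherwise. A small usage sketch of the path that the new explicit call covers (reusing the client from the sketch above; the job ID is hypothetical):

    # A job fetched by ID is already in the DONE state, so the done() polling
    # loop never runs and never populates _query_results; the explicit
    # _reload_query_results() call inside result() does it instead.
    job = client.get_job("my-finished-query-job")  # hypothetical job ID
    rows = job.result()  # jobs.getQueryResults runs here; first page is cached
    print(rows.total_rows)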
11 changes: 10 additions & 1 deletion google/cloud/bigquery/table.py
@@ -1308,7 +1308,9 @@ class RowIterator(HTTPIterator):
A subset of columns to select from this table.
total_rows (Optional[int]):
Total number of rows in the table.
first_page_response (Optional[dict]):
API response for the first page of results. These are returned when
the first page is requested.
"""

def __init__(
@@ -1324,6 +1326,7 @@ def __init__(
table=None,
selected_fields=None,
total_rows=None,
first_page_response=None,
):
super(RowIterator, self).__init__(
client,
@@ -1346,6 +1349,7 @@ def __init__(
self._selected_fields = selected_fields
self._table = table
self._total_rows = total_rows
self._first_page_response = first_page_response

def _get_next_page_response(self):
"""Requests the next page from the path provided.
@@ -1354,6 +1358,11 @@ def _get_next_page_response(self):
Dict[str, object]:
The parsed JSON response of the next page's contents.
"""
if self._first_page_response:
response = self._first_page_response
self._first_page_response = None
return response

params = self._get_query_params()
if self._page_size is not None:
if self.page_number and "startIndex" in params:
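The iterator side is where the saved request materializes: _get_next_page_response returns the cached page exactly once, then clears it so every later page (and any retry) goes to the API. The same pattern in isolation, as a minimal sketch with hypothetical names:

    class CachedFirstPageIterator:
        """Minimal sketch: serve a pre-fetched first page, then hit the API."""

        def __init__(self, fetch_page, first_page_response=None):
            self._fetch_page = fetch_page  # callable(page_token) -> response dict
            self._first_page_response = first_page_response

        def _get_next_page_response(self, page_token=None):
            if self._first_page_response:
                response = self._first_page_response
                # Clear the cache so subsequent pages come from the server.
                self._first_page_response = None
                return response
            return self._fetch_page(page_token)

Clearing the cached dict after first use keeps pagination state on the server-issued pageToken, so resuming iteration behaves exactly as before.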
55 changes: 42 additions & 13 deletions tests/unit/job/test_query.py
@@ -787,7 +787,9 @@ def test_result(self):
"location": "EU",
},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "2",
"totalRows": "3",
"rows": [{"f": [{"v": "abc"}]}],
"pageToken": "next-page",
}
job_resource = self._make_resource(started=True, location="EU")
job_resource_done = self._make_resource(started=True, ended=True, location="EU")
@@ -799,9 +801,9 @@ def test_result(self):
query_page_resource = {
# Explicitly set totalRows to be different from the initial
# response to test update during iteration.
"totalRows": "1",
"totalRows": "2",
"pageToken": None,
"rows": [{"f": [{"v": "abc"}]}],
"rows": [{"f": [{"v": "def"}]}],
}
conn = _make_connection(
query_resource, query_resource_done, job_resource_done, query_page_resource
@@ -812,19 +814,20 @@ def test_result(self):
result = job.result()

self.assertIsInstance(result, RowIterator)
self.assertEqual(result.total_rows, 2)
self.assertEqual(result.total_rows, 3)
rows = list(result)
self.assertEqual(len(rows), 1)
self.assertEqual(len(rows), 2)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(rows[1].col1, "def")
# Test that the total_rows property has changed during iteration, based
# on the response from tabledata.list.
self.assertEqual(result.total_rows, 1)
self.assertEqual(result.total_rows, 2)

query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
query_results_call = mock.call(
method="GET",
path=query_results_path,
query_params={"maxResults": 0, "location": "EU"},
query_params={"location": "EU"},
timeout=None,
)
reload_call = mock.call(
@@ -839,6 +842,7 @@ def test_result(self):
query_params={
"fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
"location": "EU",
"pageToken": "next-page",
},
timeout=None,
)
@@ -851,7 +855,9 @@ def test_result_with_done_job_calls_get_query_results(self):
"jobComplete": True,
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "1",
"totalRows": "2",
"rows": [{"f": [{"v": "abc"}]}],
"pageToken": "next-page",
}
job_resource = self._make_resource(started=True, ended=True, location="EU")
job_resource["configuration"]["query"]["destinationTable"] = {
@@ -860,9 +866,9 @@ def test_result_with_done_job_calls_get_query_results(self):
"tableId": "dest_table",
}
results_page_resource = {
"totalRows": "1",
"totalRows": "2",
"pageToken": None,
"rows": [{"f": [{"v": "abc"}]}],
"rows": [{"f": [{"v": "def"}]}],
}
conn = _make_connection(query_resource_done, results_page_resource)
client = _make_client(self.PROJECT, connection=conn)
@@ -871,14 +877,15 @@ def test_result_with_done_job_calls_get_query_results(self):
result = job.result()

rows = list(result)
self.assertEqual(len(rows), 1)
self.assertEqual(len(rows), 2)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(rows[1].col1, "def")

query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
query_results_call = mock.call(
method="GET",
path=query_results_path,
query_params={"maxResults": 0, "location": "EU"},
query_params={"location": "EU"},
timeout=None,
)
query_results_page_call = mock.call(
@@ -887,6 +894,7 @@ def test_result_with_done_job_calls_get_query_results(self):
query_params={
"fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
"location": "EU",
"pageToken": "next-page",
},
timeout=None,
)
@@ -900,6 +908,12 @@ def test_result_with_max_results(self):
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "5",
# These rows are discarded because max_results is set.
"rows": [
{"f": [{"v": "xyz"}]},
{"f": [{"v": "uvw"}]},
{"f": [{"v": "rst"}]},
],
}
query_page_resource = {
"totalRows": "5",
@@ -925,6 +939,7 @@ def test_result_with_max_results(self):
rows = list(result)

self.assertEqual(len(rows), 3)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(len(connection.api_request.call_args_list), 2)
query_page_request = connection.api_request.call_args_list[1]
self.assertEqual(
@@ -979,7 +994,7 @@ def test_result_w_retry(self):
query_results_call = mock.call(
method="GET",
path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}",
query_params={"maxResults": 0, "location": "asia-northeast1"},
query_params={"location": "asia-northeast1"},
timeout=None,
)
reload_call = mock.call(
@@ -1079,6 +1094,12 @@ def test_result_w_page_size(self):
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "4",
# These rows are discarded because page_size is set.
"rows": [
{"f": [{"v": "xyz"}]},
{"f": [{"v": "uvw"}]},
{"f": [{"v": "rst"}]},
],
}
job_resource = self._make_resource(started=True, ended=True, location="US")
q_config = job_resource["configuration"]["query"]
@@ -1109,6 +1130,7 @@ def test_result_w_page_size(self):
# Assert
actual_rows = list(result)
self.assertEqual(len(actual_rows), 4)
self.assertEqual(actual_rows[0].col1, "row1")

query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
query_page_1_call = mock.call(
@@ -1142,6 +1164,12 @@ def test_result_with_start_index(self):
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "5",
# These rows are discarded because start_index is set.
"rows": [
{"f": [{"v": "xyz"}]},
{"f": [{"v": "uvw"}]},
{"f": [{"v": "rst"}]},
],
}
tabledata_resource = {
"totalRows": "5",
Expand All @@ -1168,6 +1196,7 @@ def test_result_with_start_index(self):
rows = list(result)

self.assertEqual(len(rows), 4)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(len(connection.api_request.call_args_list), 2)
tabledata_list_request = connection.api_request.call_args_list[1]
self.assertEqual(
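All of these tests share one pattern: a stubbed connection replays canned API responses in order, and assertions on api_request pin down exactly which HTTP calls were made and which were skipped. A plausible reconstruction of that helper, sketched with unittest.mock (the real _make_connection lives in the test suite's shared fixtures and may differ in detail):

    from unittest import mock

    def make_connection(*responses):
        # Each api_request call pops the next canned response, in order; an
        # unexpected extra call raises StopIteration and fails the test.
        conn = mock.Mock(name="Connection")
        conn.api_request.side_effect = list(responses)
        return conn

    conn = make_connection({"jobComplete": True, "rows": []})
    assert conn.api_request(method="GET", path="/queries/some-job")["jobComplete"]
    conn.api_request.assert_called_once()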
