
perf: cache first page of jobs.getQueryResults rows #374

Merged · 2 commits · Nov 10, 2020
4 changes: 3 additions & 1 deletion google/cloud/bigquery/client.py
@@ -1534,7 +1534,7 @@ def _get_query_results(
A new ``_QueryResults`` instance.
"""

extra_params = {"maxResults": 0}
extra_params = {}

if project is None:
project = self.project
@@ -3187,6 +3187,7 @@ def _list_rows_from_query_results(
page_size=None,
retry=DEFAULT_RETRY,
timeout=None,
first_page_response=None,
):
"""List the rows of a completed query.
See
@@ -3247,6 +3248,7 @@ def _list_rows_from_query_results(
table=destination,
extra_params=params,
total_rows=total_rows,
first_page_response=first_page_response,
)
return row_iterator

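The dropped `maxResults: 0` above is the core of the optimization: without it, `jobs.getQueryResults` returns the first page of rows along with the job status, and the new `first_page_response` parameter threads that page down to the row iterator. A minimal sketch of the difference at the REST level (`api_request` is a stand-in for the client's connection layer, and the project and job IDs are made up):

```python
from typing import Any, Dict

def api_request(method: str, path: str, query_params: Dict[str, Any]) -> Dict[str, Any]:
    """Stand-in for the client's connection layer; returns a canned response."""
    return {"jobComplete": True, "rows": [], "pageToken": None}

path = "/projects/example-project/queries/example-job-id"

# Before: maxResults=0 made getQueryResults a status-only poll, so the rows
# always required a separate follow-up request.
status_only = api_request(method="GET", path=path, query_params={"maxResults": 0})

# After: with no maxResults cap, the same poll also carries the first page
# of rows, which the client caches and replays instead of re-fetching.
first_page = api_request(method="GET", path=path, query_params={})
rows = first_page.get("rows", [])          # cached as first_page_response
next_token = first_page.get("pageToken")   # used to request page two
```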
85 changes: 54 additions & 31 deletions google/cloud/bigquery/job/query.py
@@ -990,48 +990,22 @@ def done(self, retry=DEFAULT_RETRY, timeout=None, reload=True):
Returns:
bool: True if the job is complete, False otherwise.
"""
is_done = (
# Only consider a QueryJob complete when we know we have the final
# query results available.
self._query_results is not None
and self._query_results.complete
and self.state == _DONE_STATE
)
# Do not refresh if the state is already done, as the job will not
# change once complete.
is_done = self.state == _DONE_STATE
if not reload or is_done:
return is_done

# Since the API to getQueryResults can hang up to the timeout value
# (default of 10 seconds), set the timeout parameter to ensure that
# the timeout from the futures API is respected. See:
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
timeout_ms = None
if self._done_timeout is not None:
# Subtract a buffer for context switching, network latency, etc.
api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
api_timeout = max(min(api_timeout, 10), 0)
self._done_timeout -= api_timeout
self._done_timeout = max(0, self._done_timeout)
timeout_ms = int(api_timeout * 1000)
self._reload_query_results(retry=retry, timeout=timeout)

# If an explicit timeout is not given, fall back to the transport timeout
# stored in _blocking_poll() in the process of polling for job completion.
transport_timeout = timeout if timeout is not None else self._transport_timeout

self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
timeout_ms=timeout_ms,
location=self.location,
timeout=transport_timeout,
)

# Only reload the job once we know the query is complete.
# This will ensure that fields such as the destination table are
# correctly populated.
if self._query_results.complete and self.state != _DONE_STATE:
if self._query_results.complete:
self.reload(retry=retry, timeout=transport_timeout)

return self.state == _DONE_STATE
@@ -1098,6 +1072,45 @@ def _begin(self, client=None, retry=DEFAULT_RETRY, timeout=None):
exc.query_job = self
raise

def _reload_query_results(self, retry=DEFAULT_RETRY, timeout=None):
"""Refresh the cached query results.

Args:
retry (Optional[google.api_core.retry.Retry]):
How to retry the call that retrieves query results.
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
"""
if self._query_results and self._query_results.complete:
return

# Since the API to getQueryResults can hang up to the timeout value
# (default of 10 seconds), set the timeout parameter to ensure that
# the timeout from the futures API is respected. See:
# https://github.com/GoogleCloudPlatform/google-cloud-python/issues/4135
timeout_ms = None
if self._done_timeout is not None:
# Subtract a buffer for context switching, network latency, etc.
api_timeout = self._done_timeout - _TIMEOUT_BUFFER_SECS
api_timeout = max(min(api_timeout, 10), 0)
self._done_timeout -= api_timeout
self._done_timeout = max(0, self._done_timeout)
timeout_ms = int(api_timeout * 1000)

# If an explicit timeout is not given, fall back to the transport timeout
# stored in _blocking_poll() in the process of polling for job completion.
transport_timeout = timeout if timeout is not None else self._transport_timeout

self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
timeout_ms=timeout_ms,
location=self.location,
timeout=transport_timeout,
)

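The timeout handling that moved out of `done()` into `_reload_query_results` is easier to follow in isolation. Below is a standalone sketch of that budgeting arithmetic; `_TIMEOUT_BUFFER_SECS` is internal to the library, and the 0.1 value here is only an assumed illustration:

```python
_TIMEOUT_BUFFER_SECS = 0.1  # assumed value; the real constant is internal to the job module

def next_api_timeout(done_timeout):
    """Carve one getQueryResults budget out of the remaining futures timeout.

    Returns (timeout in ms for this request, seconds left for later polls).
    """
    # Reserve a buffer for context switching, network latency, etc.
    api_timeout = done_timeout - _TIMEOUT_BUFFER_SECS
    # getQueryResults can hang up to ~10 seconds server side, so cap any
    # single request at 10 seconds and never let the budget go negative.
    api_timeout = max(min(api_timeout, 10), 0)
    remaining = max(0, done_timeout - api_timeout)
    return int(api_timeout * 1000), remaining

# With 25 seconds left overall, this poll gets the full 10-second cap and
# 15 seconds remain for subsequent polls.
assert next_api_timeout(25.0) == (10000, 15.0)
```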
def result(
self,
page_size=None,
@@ -1144,6 +1157,11 @@ def result(
"""
try:
super(QueryJob, self).result(retry=retry, timeout=timeout)

# Since the job could already be "done" (e.g. got a finished job
# via client.get_job), the superclass call to done() might not
# set the self._query_results cache.
self._reload_query_results(retry=retry, timeout=timeout)
except exceptions.GoogleAPICallError as exc:
exc.message += self._format_for_exception(self.query, self.job_id)
exc.query_job = self
@@ -1158,10 +1176,14 @@ def result(
if self._query_results.total_rows is None:
return _EmptyRowIterator()

first_page_response = None
if max_results is None and page_size is None and start_index is None:
first_page_response = self._query_results._properties

rows = self._client._list_rows_from_query_results(
self._query_results.job_id,
self.job_id,
self.location,
self._query_results.project,
self.project,
self._query_results.schema,
total_rows=self._query_results.total_rows,
destination=self.destination,
@@ -1170,6 +1192,7 @@ def result(
start_index=start_index,
retry=retry,
timeout=timeout,
first_page_response=first_page_response,
)
rows._preserve_order = _contains_order_by(self.query)
return rows
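Note the guard in `result()` above: the cached page is only handed to the iterator when `max_results`, `page_size`, and `start_index` are all unset, since the cached response was fetched without those options. A hedged usage sketch (the query text and public table are illustrative, and a configured `bigquery.Client` is assumed):

```python
from google.cloud import bigquery

client = bigquery.Client()  # assumes application default credentials
job = client.query(
    "SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` LIMIT 100"
)

# Plain iteration: max_results, page_size, and start_index are all None, so
# the rows already returned by jobs.getQueryResults are replayed from the
# cache and the first list-rows request is skipped.
for row in job.result():
    pass

# Any of these options bypasses the cache, because the cached page was
# fetched without them and may not match the requested slice.
rows = job.result(max_results=10)
rows = job.result(page_size=50)
rows = job.result(start_index=5)
```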
11 changes: 10 additions & 1 deletion google/cloud/bigquery/table.py
@@ -1308,7 +1308,9 @@ class RowIterator(HTTPIterator):
A subset of columns to select from this table.
total_rows (Optional[int]):
Total number of rows in the table.

first_page_response (Optional[dict]):
    API response for the first page of results. If present, this
    response is served as the first page instead of issuing an API call.
"""

def __init__(
@@ -1324,6 +1326,7 @@ def __init__(
table=None,
selected_fields=None,
total_rows=None,
first_page_response=None,
):
super(RowIterator, self).__init__(
client,
@@ -1346,6 +1349,7 @@ def __init__(
self._selected_fields = selected_fields
self._table = table
self._total_rows = total_rows
self._first_page_response = first_page_response

def _get_next_page_response(self):
"""Requests the next page from the path provided.
@@ -1354,6 +1358,11 @@ def _get_next_page_response(self):
Dict[str, object]:
The parsed JSON response of the next page's contents.
"""
if self._first_page_response:
response = self._first_page_response
self._first_page_response = None
return response

params = self._get_query_params()
if self._page_size is not None:
if self.page_number and "startIndex" in params:
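The `_get_next_page_response` change above is a consume-once cache: the stored response is served for the first page and then cleared, so every later page goes over the wire. A minimal standalone restatement of the pattern (the real iterator also builds query parameters and handles page size):

```python
class CachedFirstPage:
    """Toy illustration of RowIterator's consume-once first-page cache."""

    def __init__(self, first_page_response=None):
        self._first_page_response = first_page_response

    def _get_next_page_response(self):
        # Serve the response cached from jobs.getQueryResults exactly once,
        # then fall through to real requests for every later page.
        if self._first_page_response:
            response = self._first_page_response
            self._first_page_response = None
            return response
        return self._fetch_page_from_api()

    def _fetch_page_from_api(self):
        # Placeholder for the authorized HTTP call the real iterator makes.
        return {"rows": [], "pageToken": None}

pages = CachedFirstPage({"rows": [{"f": [{"v": "abc"}]}], "pageToken": "next-page"})
assert pages._get_next_page_response()["pageToken"] == "next-page"  # from cache
assert pages._get_next_page_response()["rows"] == []                # from "API"
```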
55 changes: 42 additions & 13 deletions tests/unit/job/test_query.py
@@ -787,7 +787,9 @@ def test_result(self):
"location": "EU",
},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "2",
"totalRows": "3",
"rows": [{"f": [{"v": "abc"}]}],
"pageToken": "next-page",
}
job_resource = self._make_resource(started=True, location="EU")
job_resource_done = self._make_resource(started=True, ended=True, location="EU")
@@ -799,9 +801,9 @@
query_page_resource = {
# Explicitly set totalRows to be different from the initial
# response to test update during iteration.
"totalRows": "1",
"totalRows": "2",
"pageToken": None,
"rows": [{"f": [{"v": "abc"}]}],
"rows": [{"f": [{"v": "def"}]}],
}
conn = _make_connection(
query_resource, query_resource_done, job_resource_done, query_page_resource
@@ -812,19 +814,20 @@
result = job.result()

self.assertIsInstance(result, RowIterator)
self.assertEqual(result.total_rows, 2)
self.assertEqual(result.total_rows, 3)
rows = list(result)
self.assertEqual(len(rows), 1)
self.assertEqual(len(rows), 2)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(rows[1].col1, "def")
# Test that the total_rows property has changed during iteration, based
# on the response from tabledata.list.
self.assertEqual(result.total_rows, 1)
self.assertEqual(result.total_rows, 2)

query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
query_results_call = mock.call(
method="GET",
path=query_results_path,
query_params={"maxResults": 0, "location": "EU"},
query_params={"location": "EU"},
timeout=None,
)
reload_call = mock.call(
@@ -839,6 +842,7 @@
query_params={
"fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
"location": "EU",
"pageToken": "next-page",
},
timeout=None,
)
@@ -851,7 +855,9 @@ def test_result_with_done_job_calls_get_query_results(self):
"jobComplete": True,
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "1",
"totalRows": "2",
"rows": [{"f": [{"v": "abc"}]}],
"pageToken": "next-page",
}
job_resource = self._make_resource(started=True, ended=True, location="EU")
job_resource["configuration"]["query"]["destinationTable"] = {
@@ -860,9 +866,9 @@ def test_result_with_done_job_calls_get_query_results(self):
"tableId": "dest_table",
}
results_page_resource = {
"totalRows": "1",
"totalRows": "2",
"pageToken": None,
"rows": [{"f": [{"v": "abc"}]}],
"rows": [{"f": [{"v": "def"}]}],
}
conn = _make_connection(query_resource_done, results_page_resource)
client = _make_client(self.PROJECT, connection=conn)
@@ -871,14 +877,15 @@ def test_result_with_done_job_calls_get_query_results(self):
result = job.result()

rows = list(result)
self.assertEqual(len(rows), 1)
self.assertEqual(len(rows), 2)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(rows[1].col1, "def")

query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
query_results_call = mock.call(
method="GET",
path=query_results_path,
query_params={"maxResults": 0, "location": "EU"},
query_params={"location": "EU"},
timeout=None,
)
query_results_page_call = mock.call(
@@ -887,6 +894,7 @@ def test_result_with_done_job_calls_get_query_results(self):
query_params={
"fields": _LIST_ROWS_FROM_QUERY_RESULTS_FIELDS,
"location": "EU",
"pageToken": "next-page",
},
timeout=None,
)
@@ -900,6 +908,12 @@ def test_result_with_max_results(self):
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "5",
# These rows are discarded because max_results is set.
"rows": [
{"f": [{"v": "xyz"}]},
{"f": [{"v": "uvw"}]},
{"f": [{"v": "rst"}]},
],
}
query_page_resource = {
"totalRows": "5",
@@ -925,6 +939,7 @@ def test_result_with_max_results(self):
rows = list(result)

self.assertEqual(len(rows), 3)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(len(connection.api_request.call_args_list), 2)
query_page_request = connection.api_request.call_args_list[1]
self.assertEqual(
@@ -979,7 +994,7 @@ def test_result_w_retry(self):
query_results_call = mock.call(
method="GET",
path=f"/projects/{self.PROJECT}/queries/{self.JOB_ID}",
query_params={"maxResults": 0, "location": "asia-northeast1"},
query_params={"location": "asia-northeast1"},
timeout=None,
)
reload_call = mock.call(
@@ -1079,6 +1094,12 @@ def test_result_w_page_size(self):
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "4",
# These rows are discarded because page_size is set.
"rows": [
{"f": [{"v": "xyz"}]},
{"f": [{"v": "uvw"}]},
{"f": [{"v": "rst"}]},
],
}
job_resource = self._make_resource(started=True, ended=True, location="US")
q_config = job_resource["configuration"]["query"]
@@ -1109,6 +1130,7 @@ def test_result_w_page_size(self):
# Assert
actual_rows = list(result)
self.assertEqual(len(actual_rows), 4)
self.assertEqual(actual_rows[0].col1, "row1")

query_results_path = f"/projects/{self.PROJECT}/queries/{self.JOB_ID}"
query_page_1_call = mock.call(
@@ -1142,6 +1164,12 @@ def test_result_with_start_index(self):
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
"totalRows": "5",
# These rows are discarded because start_index is set.
"rows": [
{"f": [{"v": "xyz"}]},
{"f": [{"v": "uvw"}]},
{"f": [{"v": "rst"}]},
],
}
tabledata_resource = {
"totalRows": "5",
@@ -1168,6 +1196,7 @@ def test_result_with_start_index(self):
rows = list(result)

self.assertEqual(len(rows), 4)
self.assertEqual(rows[0].col1, "abc")
self.assertEqual(len(connection.api_request.call_args_list), 2)
tabledata_list_request = connection.api_request.call_args_list[1]
self.assertEqual(
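Taken together, the updated tests pin down the new request choreography: the getQueryResults poll now doubles as the first page, and only later pages hit the API. A condensed, standalone restatement of the mocked exchange, using a bare `unittest.mock` connection in place of the `_make_connection` helper; the resource dicts mirror the trimmed fixtures above:

```python
from unittest import mock

connection = mock.Mock()
connection.api_request.side_effect = [
    {
        # jobs.getQueryResults: job status plus the cached first page.
        "jobComplete": True,
        "jobReference": {"projectId": "test-project", "jobId": "test-job"},
        "schema": {"fields": [{"name": "col1", "type": "STRING"}]},
        "totalRows": "2",
        "rows": [{"f": [{"v": "abc"}]}],
        "pageToken": "next-page",
    },
    {
        # Second page: the only row request that actually goes out.
        "totalRows": "2",
        "pageToken": None,
        "rows": [{"f": [{"v": "def"}]}],
    },
]

# Simulate the drain: one poll, then one fetch per remaining pageToken.
path = "/projects/test-project/queries/test-job"
responses = [connection.api_request(method="GET", path=path)]
while responses[-1].get("pageToken"):
    responses.append(
        connection.api_request(
            method="GET",
            path=path,
            query_params={"pageToken": responses[-1]["pageToken"]},
        )
    )

# Exactly two requests, as the tests assert: the poll (whose rows are
# replayed from cache) and the page-two fetch with pageToken="next-page".
assert connection.api_request.call_count == 2
```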