Skip to content

Commit

Permalink
fix: distinguish server timeouts from transport timeouts (#43)
Browse files Browse the repository at this point in the history
* fix: distinguish transport and query timeouts

A transport layer timeout is made independent of the query timeout,
i.e. the maximum time to wait for the query to complete.

The query timeout is used by the blocking poll so that the backend
does not block for too long when polling for job completion, but
the transport can have different timeout requirements, and we do
not want it to be raising sometimes unnecessary timeout errors.

* Apply timeout to each of the underlying requests

As job methods do not split the timeout anymore between all requests a
method might make, the Client methods are adjusted in the same way.
  • Loading branch information
plamut committed Mar 9, 2020
1 parent 24f3910 commit a17be5f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 233 deletions.
37 changes: 13 additions & 24 deletions google/cloud/bigquery/client.py
Expand Up @@ -22,7 +22,6 @@
except ImportError: # Python 2.7
import collections as collections_abc

import concurrent.futures
import copy
import functools
import gzip
Expand All @@ -48,7 +47,6 @@
import google.api_core.client_options
import google.api_core.exceptions
from google.api_core import page_iterator
from google.auth.transport.requests import TimeoutGuard
import google.cloud._helpers
from google.cloud import exceptions
from google.cloud.client import ClientWithProject
Expand Down Expand Up @@ -2598,27 +2596,22 @@ def list_partitions(self, table, retry=DEFAULT_RETRY, timeout=None):
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.
Returns:
List[str]:
A list of the partition ids present in the partitioned table
"""
table = _table_arg_to_table_ref(table, default_project=self.project)

with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
meta_table = self.get_table(
TableReference(
DatasetReference(table.project, table.dataset_id),
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
),
retry=retry,
timeout=timeout,
)
timeout = guard.remaining_timeout
meta_table = self.get_table(
TableReference(
DatasetReference(table.project, table.dataset_id),
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
),
retry=retry,
timeout=timeout,
)

subset = [col for col in meta_table.schema if col.name == "partition_id"]
return [
Expand Down Expand Up @@ -2685,8 +2678,8 @@ def list_rows(
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.
Returns:
google.cloud.bigquery.table.RowIterator:
Expand All @@ -2711,11 +2704,7 @@ def list_rows(
# No schema, but no selected_fields. Assume the developer wants all
# columns, so get the table resource for them rather than failing.
elif len(schema) == 0:
with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
table = self.get_table(table.reference, retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
table = self.get_table(table.reference, retry=retry, timeout=timeout)
schema = table.schema

params = {}
Expand Down
62 changes: 19 additions & 43 deletions google/cloud/bigquery/job.py
Expand Up @@ -26,7 +26,6 @@
from six.moves import http_client

import google.api_core.future.polling
from google.auth.transport.requests import TimeoutGuard
from google.cloud import exceptions
from google.cloud.exceptions import NotFound
from google.cloud.bigquery.dataset import Dataset
Expand Down Expand Up @@ -55,7 +54,6 @@
_DONE_STATE = "DONE"
_STOPPED_REASON = "stopped"
_TIMEOUT_BUFFER_SECS = 0.1
_SERVER_TIMEOUT_MARGIN_SECS = 1.0
_CONTAINS_ORDER_BY = re.compile(r"ORDER\s+BY", re.IGNORECASE)

_ERROR_REASON_TO_EXCEPTION = {
Expand Down Expand Up @@ -796,8 +794,8 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.
Returns:
_AsyncJob: This instance.
Expand All @@ -809,11 +807,7 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
if the job did not complete in the given timeout.
"""
if self.state is None:
with TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
) as guard:
self._begin(retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
self._begin(retry=retry, timeout=timeout)
# TODO: modify PollingFuture so it can pass a retry argument to done().
return super(_AsyncJob, self).result(timeout=timeout)

Expand Down Expand Up @@ -2602,6 +2596,7 @@ def __init__(self, job_id, query, client, job_config=None):
self._configuration = job_config
self._query_results = None
self._done_timeout = None
self._transport_timeout = None

@property
def allow_large_results(self):
Expand Down Expand Up @@ -3059,19 +3054,9 @@ def done(self, retry=DEFAULT_RETRY, timeout=None):
self._done_timeout = max(0, self._done_timeout)
timeout_ms = int(api_timeout * 1000)

# If the server-side processing timeout (timeout_ms) is specified and
# would be picked as the total request timeout, we want to add a small
# margin to it - we don't want to timeout the connection just as the
# server-side processing might have completed, but instead slightly
# after the server-side deadline.
# However, if `timeout` is specified, and is shorter than the adjusted
# server timeout, the former prevails.
if timeout_ms is not None and timeout_ms > 0:
server_timeout_with_margin = timeout_ms / 1000 + _SERVER_TIMEOUT_MARGIN_SECS
if timeout is not None:
timeout = min(server_timeout_with_margin, timeout)
else:
timeout = server_timeout_with_margin
# If an explicit timeout is not given, fall back to the transport timeout
# stored in _blocking_poll() in the process of polling for job completion.
transport_timeout = timeout if timeout is not None else self._transport_timeout

# Do not refresh if the state is already done, as the job will not
# change once complete.
Expand All @@ -3082,19 +3067,20 @@ def done(self, retry=DEFAULT_RETRY, timeout=None):
project=self.project,
timeout_ms=timeout_ms,
location=self.location,
timeout=timeout,
timeout=transport_timeout,
)

# Only reload the job once we know the query is complete.
# This will ensure that fields such as the destination table are
# correctly populated.
if self._query_results.complete:
self.reload(retry=retry, timeout=timeout)
self.reload(retry=retry, timeout=transport_timeout)

return self.state == _DONE_STATE

def _blocking_poll(self, timeout=None):
self._done_timeout = timeout
self._transport_timeout = timeout
super(QueryJob, self)._blocking_poll(timeout=timeout)

@staticmethod
Expand Down Expand Up @@ -3170,8 +3156,8 @@ def result(
timeout (Optional[float]):
The number of seconds to wait for the underlying HTTP transport
before using ``retry``.
If multiple requests are made under the hood, ``timeout`` is
interpreted as the approximate total time of **all** requests.
If multiple requests are made under the hood, ``timeout``
applies to each individual request.
Returns:
google.cloud.bigquery.table.RowIterator:
Expand All @@ -3189,27 +3175,17 @@ def result(
If the job did not complete in the given timeout.
"""
try:
guard = TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
)
with guard:
super(QueryJob, self).result(retry=retry, timeout=timeout)
timeout = guard.remaining_timeout
super(QueryJob, self).result(retry=retry, timeout=timeout)

# Return an iterator instead of returning the job.
if not self._query_results:
guard = TimeoutGuard(
timeout, timeout_error_type=concurrent.futures.TimeoutError
self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
location=self.location,
timeout=timeout,
)
with guard:
self._query_results = self._client._get_query_results(
self.job_id,
retry,
project=self.project,
location=self.location,
timeout=timeout,
)
timeout = guard.remaining_timeout
except exceptions.GoogleCloudError as exc:
exc.message += self._format_for_exception(self.query, self.job_id)
exc.query_job = self
Expand Down
78 changes: 0 additions & 78 deletions tests/unit/test_client.py
Expand Up @@ -24,7 +24,6 @@
import unittest
import warnings

import freezegun
import mock
import requests
import six
Expand Down Expand Up @@ -5496,43 +5495,6 @@ def test_list_partitions_with_string_id(self):

self.assertEqual(len(partition_list), 0)

def test_list_partitions_splitting_timout_between_requests(self):
from google.cloud.bigquery.table import Table

row_count = 2
meta_info = _make_list_partitons_meta_info(
self.PROJECT, self.DS_ID, self.TABLE_ID, row_count
)

data = {
"totalRows": str(row_count),
"rows": [{"f": [{"v": "20180101"}]}, {"f": [{"v": "20180102"}]}],
}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
client._connection = make_connection(meta_info, data)
table = Table(self.TABLE_REF)

with freezegun.freeze_time("2019-01-01 00:00:00", tick=False) as frozen_time:

def delayed_get_table(*args, **kwargs):
frozen_time.tick(delta=1.4)
return orig_get_table(*args, **kwargs)

orig_get_table = client.get_table
client.get_table = mock.Mock(side_effect=delayed_get_table)

client.list_partitions(table, timeout=5.0)

client.get_table.assert_called_once()
_, kwargs = client.get_table.call_args
self.assertEqual(kwargs.get("timeout"), 5.0)

client._connection.api_request.assert_called()
_, kwargs = client._connection.api_request.call_args
self.assertAlmostEqual(kwargs.get("timeout"), 3.6, places=5)

def test_list_rows(self):
import datetime
from google.cloud._helpers import UTC
Expand Down Expand Up @@ -5918,46 +5880,6 @@ def test_list_rows_with_missing_schema(self):
self.assertEqual(rows[1].age, 31, msg=repr(table))
self.assertIsNone(rows[2].age, msg=repr(table))

def test_list_rows_splitting_timout_between_requests(self):
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery.table import Table

response = {"totalRows": "0", "rows": []}
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
client._connection = make_connection(response, response)

table = Table(
self.TABLE_REF, schema=[SchemaField("field_x", "INTEGER", mode="NULLABLE")]
)

with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen_time:

def delayed_get_table(*args, **kwargs):
frozen_time.tick(delta=1.4)
return table

client.get_table = mock.Mock(side_effect=delayed_get_table)

rows_iter = client.list_rows(
"{}.{}.{}".format(
self.TABLE_REF.project,
self.TABLE_REF.dataset_id,
self.TABLE_REF.table_id,
),
timeout=5.0,
)
six.next(rows_iter.pages)

client.get_table.assert_called_once()
_, kwargs = client.get_table.call_args
self.assertEqual(kwargs.get("timeout"), 5.0)

client._connection.api_request.assert_called_once()
_, kwargs = client._connection.api_request.call_args
self.assertAlmostEqual(kwargs.get("timeout"), 3.6)

def test_list_rows_error(self):
creds = _make_credentials()
http = object()
Expand Down

0 comments on commit a17be5f

Please sign in to comment.