Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set the X-Server-Timeout header when timeout is set #927

Merged
merged 4 commits into from Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 26 additions & 1 deletion google/cloud/bigquery/client.py
Expand Up @@ -131,6 +131,8 @@
# https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414
_PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")])

TIMEOUT_HEADER = "X-Server-Timeout"


class Project(object):
"""Wrapper for resource describing a BigQuery project.
Expand Down Expand Up @@ -742,16 +744,26 @@ def create_table(
return self.get_table(table.reference, retry=retry)

def _call_api(
self, retry, span_name=None, span_attributes=None, job_ref=None, **kwargs
self,
retry,
span_name=None,
span_attributes=None,
job_ref=None,
headers: Optional[Dict[str, str]] = None,
**kwargs,
):
kwargs = _add_server_timeout_header(headers, kwargs)
call = functools.partial(self._connection.api_request, **kwargs)

if retry:
call = retry(call)

if span_name is not None:
with create_span(
name=span_name, attributes=span_attributes, client=self, job_ref=job_ref
):
return call()

return call()

def get_dataset(
Expand Down Expand Up @@ -4045,3 +4057,16 @@ def _get_upload_headers(user_agent):
"User-Agent": user_agent,
"content-type": "application/json",
}


def _add_server_timeout_header(headers: Optional[Dict[str, str]], kwargs):
timeout = kwargs.get("timeout")
if timeout is not None:
if headers is None:
headers = {}
headers[TIMEOUT_HEADER] = str(timeout)

if headers:
kwargs["headers"] = headers

return kwargs
19 changes: 19 additions & 0 deletions tests/unit/conftest.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import mock
import pytest

from .helpers import make_client
Expand All @@ -35,3 +36,21 @@ def DS_ID():
@pytest.fixture
def LOCATION():
yield "us-central"


def noop_add_server_timeout_header(headers, kwargs):
if headers:
kwargs["headers"] = headers
return kwargs


@pytest.fixture(autouse=True)
def disable_add_server_timeout_header(request):
if "enable_add_server_timeout_header" in request.keywords:
yield
else:
with mock.patch(
"google.cloud.bigquery.client._add_server_timeout_header",
noop_add_server_timeout_header,
):
yield
27 changes: 19 additions & 8 deletions tests/unit/test_client.py
Expand Up @@ -1806,7 +1806,6 @@ def test_update_dataset(self):
"access": ACCESS,
},
path="/" + PATH,
headers=None,
timeout=7.5,
)
self.assertEqual(ds2.description, ds.description)
Expand Down Expand Up @@ -1850,7 +1849,6 @@ def test_update_dataset_w_custom_property(self):
method="PATCH",
data={"newAlphaProperty": "unreleased property"},
path=path,
headers=None,
timeout=DEFAULT_TIMEOUT,
)

Expand Down Expand Up @@ -1909,7 +1907,7 @@ def test_update_model(self):
"labels": {"x": "y"},
}
conn.api_request.assert_called_once_with(
method="PATCH", data=sent, path="/" + path, headers=None, timeout=7.5
method="PATCH", data=sent, path="/" + path, timeout=7.5
)
self.assertEqual(updated_model.model_id, model.model_id)
self.assertEqual(updated_model.description, model.description)
Expand Down Expand Up @@ -1982,7 +1980,6 @@ def test_update_routine(self):
method="PUT",
data=sent,
path="/projects/routines-project/datasets/test_routines/routines/updated_routine",
headers=None,
timeout=7.5,
)
self.assertEqual(actual_routine.arguments, routine.arguments)
Expand Down Expand Up @@ -2090,7 +2087,7 @@ def test_update_table(self):
"labels": {"x": "y"},
}
conn.api_request.assert_called_once_with(
method="PATCH", data=sent, path="/" + path, headers=None, timeout=7.5
method="PATCH", data=sent, path="/" + path, timeout=7.5
)
self.assertEqual(updated_table.description, table.description)
self.assertEqual(updated_table.friendly_name, table.friendly_name)
Expand Down Expand Up @@ -2140,7 +2137,6 @@ def test_update_table_w_custom_property(self):
method="PATCH",
path="/%s" % path,
data={"newAlphaProperty": "unreleased property"},
headers=None,
timeout=DEFAULT_TIMEOUT,
)
self.assertEqual(
Expand Down Expand Up @@ -2175,7 +2171,6 @@ def test_update_table_only_use_legacy_sql(self):
method="PATCH",
path="/%s" % path,
data={"view": {"useLegacySql": True}},
headers=None,
timeout=DEFAULT_TIMEOUT,
)
self.assertEqual(updated_table.view_use_legacy_sql, table.view_use_legacy_sql)
Expand Down Expand Up @@ -2273,7 +2268,6 @@ def test_update_table_w_query(self):
"expirationTime": str(_millis(exp_time)),
"schema": schema_resource,
},
headers=None,
timeout=DEFAULT_TIMEOUT,
)

Expand Down Expand Up @@ -8173,3 +8167,20 @@ def transmit_next_chunk(transport):

chunk_size = RU.call_args_list[0][0][1]
assert chunk_size == 100 * (1 << 20)


@pytest.mark.enable_add_server_timeout_header
@pytest.mark.parametrize("headers", [None, {}])
def test__call_api_add_server_timeout_w_timeout(client, headers):
client._connection = make_connection({})
client._call_api(None, method="GET", path="/", headers=headers, timeout=42)
client._connection.api_request.assert_called_with(
method="GET", path="/", timeout=42, headers={"X-Server-Timeout": "42"}
)


@pytest.mark.enable_add_server_timeout_header
def test__call_api_no_add_server_timeout_wo_timeout(client):
client._connection = make_connection({})
client._call_api(None, method="GET", path="/")
client._connection.api_request.assert_called_with(method="GET", path="/")