From ba02f248ba9c449c34859579a4011f4bfd2f4a93 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Wed, 1 Sep 2021 14:22:16 -0600 Subject: [PATCH] feat: set the X-Server-Timeout header when timeout is set (#927) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) Fixes #919 🦕 --- google/cloud/bigquery/client.py | 27 ++++++++++++++++++++++++++- tests/unit/conftest.py | 19 +++++++++++++++++++ tests/unit/test_client.py | 27 +++++++++++++++++++-------- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 023346ffa..47ff83c5d 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -131,6 +131,8 @@ # https://github.com/googleapis/python-bigquery/issues/781#issuecomment-883497414 _PYARROW_BAD_VERSIONS = frozenset([packaging.version.Version("2.0.0")]) +TIMEOUT_HEADER = "X-Server-Timeout" + class Project(object): """Wrapper for resource describing a BigQuery project. @@ -742,16 +744,26 @@ def create_table( return self.get_table(table.reference, retry=retry) def _call_api( - self, retry, span_name=None, span_attributes=None, job_ref=None, **kwargs + self, + retry, + span_name=None, + span_attributes=None, + job_ref=None, + headers: Optional[Dict[str, str]] = None, + **kwargs, ): + kwargs = _add_server_timeout_header(headers, kwargs) call = functools.partial(self._connection.api_request, **kwargs) + if retry: call = retry(call) + if span_name is not None: with create_span( name=span_name, attributes=span_attributes, client=self, job_ref=job_ref ): return call() + return call() def get_dataset( @@ -4045,3 +4057,16 @@ def _get_upload_headers(user_agent): "User-Agent": user_agent, "content-type": "application/json", } + + +def _add_server_timeout_header(headers: Optional[Dict[str, str]], kwargs): + timeout = kwargs.get("timeout") + if timeout is not None: + if headers is None: + headers = {} + headers[TIMEOUT_HEADER] = str(timeout) + + if headers: + kwargs["headers"] = headers + + return kwargs diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7a67ea6b5..feba65aa5 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import mock import pytest from .helpers import make_client @@ -35,3 +36,21 @@ def DS_ID(): @pytest.fixture def LOCATION(): yield "us-central" + + +def noop_add_server_timeout_header(headers, kwargs): + if headers: + kwargs["headers"] = headers + return kwargs + + +@pytest.fixture(autouse=True) +def disable_add_server_timeout_header(request): + if "enable_add_server_timeout_header" in request.keywords: + yield + else: + with mock.patch( + "google.cloud.bigquery.client._add_server_timeout_header", + noop_add_server_timeout_header, + ): + yield diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index e9204f1de..d2a75413f 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1806,7 +1806,6 @@ def test_update_dataset(self): "access": ACCESS, }, path="/" + PATH, - headers=None, timeout=7.5, ) self.assertEqual(ds2.description, ds.description) @@ -1850,7 +1849,6 @@ def test_update_dataset_w_custom_property(self): method="PATCH", data={"newAlphaProperty": "unreleased property"}, path=path, - headers=None, timeout=DEFAULT_TIMEOUT, ) @@ -1909,7 +1907,7 @@ def test_update_model(self): "labels": {"x": "y"}, } conn.api_request.assert_called_once_with( - method="PATCH", data=sent, path="/" + path, headers=None, timeout=7.5 + method="PATCH", data=sent, path="/" + path, timeout=7.5 ) self.assertEqual(updated_model.model_id, model.model_id) self.assertEqual(updated_model.description, model.description) @@ -1982,7 +1980,6 @@ def test_update_routine(self): method="PUT", data=sent, path="/projects/routines-project/datasets/test_routines/routines/updated_routine", - headers=None, timeout=7.5, ) self.assertEqual(actual_routine.arguments, routine.arguments) @@ -2090,7 +2087,7 @@ def test_update_table(self): "labels": {"x": "y"}, } conn.api_request.assert_called_once_with( - method="PATCH", data=sent, path="/" + path, headers=None, timeout=7.5 + method="PATCH", data=sent, path="/" + path, timeout=7.5 ) self.assertEqual(updated_table.description, table.description) self.assertEqual(updated_table.friendly_name, table.friendly_name) @@ -2140,7 +2137,6 @@ def test_update_table_w_custom_property(self): method="PATCH", path="/%s" % path, data={"newAlphaProperty": "unreleased property"}, - headers=None, timeout=DEFAULT_TIMEOUT, ) self.assertEqual( @@ -2175,7 +2171,6 @@ def test_update_table_only_use_legacy_sql(self): method="PATCH", path="/%s" % path, data={"view": {"useLegacySql": True}}, - headers=None, timeout=DEFAULT_TIMEOUT, ) self.assertEqual(updated_table.view_use_legacy_sql, table.view_use_legacy_sql) @@ -2273,7 +2268,6 @@ def test_update_table_w_query(self): "expirationTime": str(_millis(exp_time)), "schema": schema_resource, }, - headers=None, timeout=DEFAULT_TIMEOUT, ) @@ -8173,3 +8167,20 @@ def transmit_next_chunk(transport): chunk_size = RU.call_args_list[0][0][1] assert chunk_size == 100 * (1 << 20) + + +@pytest.mark.enable_add_server_timeout_header +@pytest.mark.parametrize("headers", [None, {}]) +def test__call_api_add_server_timeout_w_timeout(client, headers): + client._connection = make_connection({}) + client._call_api(None, method="GET", path="/", headers=headers, timeout=42) + client._connection.api_request.assert_called_with( + method="GET", path="/", timeout=42, headers={"X-Server-Timeout": "42"} + ) + + +@pytest.mark.enable_add_server_timeout_header +def test__call_api_no_add_server_timeout_wo_timeout(client): + client._connection = make_connection({}) + client._call_api(None, method="GET", path="/") + client._connection.api_request.assert_called_with(method="GET", path="/")