diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index f8c0d7c93..bdbcb767c 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -3423,6 +3423,12 @@ def schema_to_json(self, schema_list, destination): with open(destination, mode="w") as file_obj: return self._schema_to_json_file_object(json_schema_list, file_obj) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + # pylint: disable=unused-argument def _item_to_project(iterator, resource): diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 66add9c0a..6c3263ea5 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -7218,6 +7218,28 @@ def test_list_rows_error(self): with self.assertRaises(TypeError): client.list_rows(1) + def test_context_manager_enter_returns_itself(self): + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + + with mock.patch.object(client, "close"), client as context_var: + pass + + self.assertIs(client, context_var) + + def test_context_manager_exit_closes_client(self): + creds = _make_credentials() + http = object() + client = self._make_one(project=self.PROJECT, credentials=creds, _http=http) + + fake_close = mock.Mock() + with mock.patch.object(client, "close", fake_close): + with client: + pass + + fake_close.assert_called_once() + class Test_make_job_id(unittest.TestCase): def _call_fut(self, job_id, prefix=None):