From 6160fee4b1a79b0ea9031cc18caf6322fe4c4084 Mon Sep 17 00:00:00 2001 From: HemangChothani <50404902+HemangChothani@users.noreply.github.com> Date: Fri, 18 Sep 2020 19:41:28 +0530 Subject: [PATCH] fix: validate job_config.source_format in load_table_from_dataframe (#262) * fix: address job_config.source_format * fix: nit --- google/cloud/bigquery/client.py | 10 ++++- tests/unit/test_client.py | 78 ++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py index 86275487b..d2aa45999 100644 --- a/google/cloud/bigquery/client.py +++ b/google/cloud/bigquery/client.py @@ -2174,7 +2174,15 @@ def load_table_from_dataframe( else: job_config = job.LoadJobConfig() - job_config.source_format = job.SourceFormat.PARQUET + if job_config.source_format: + if job_config.source_format != job.SourceFormat.PARQUET: + raise ValueError( + "Got unexpected source_format: '{}'. Currently, only PARQUET is supported".format( + job_config.source_format + ) + ) + else: + job_config.source_format = job.SourceFormat.PARQUET if location is None: location = self.location diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d354735a1..00bc47017 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -7544,7 +7544,7 @@ def test_load_table_from_dataframe_w_client_location(self): @unittest.skipIf(pandas is None, "Requires `pandas`") @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") - def test_load_table_from_dataframe_w_custom_job_config(self): + def test_load_table_from_dataframe_w_custom_job_config_without_source_format(self): from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES from google.cloud.bigquery import job from google.cloud.bigquery.schema import SchemaField @@ -7553,7 +7553,7 @@ def test_load_table_from_dataframe_w_custom_job_config(self): records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] dataframe = pandas.DataFrame(records) 
job_config = job.LoadJobConfig( - write_disposition=job.WriteDisposition.WRITE_TRUNCATE + write_disposition=job.WriteDisposition.WRITE_TRUNCATE, ) original_config_copy = copy.deepcopy(job_config) @@ -7595,6 +7595,80 @@ def test_load_table_from_dataframe_w_custom_job_config(self): # the original config object should not have been modified assert job_config.to_api_repr() == original_config_copy.to_api_repr() + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_w_custom_job_config_w_source_format(self): + from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES + from google.cloud.bigquery import job + from google.cloud.bigquery.schema import SchemaField + + client = self._make_client() + records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] + dataframe = pandas.DataFrame(records) + job_config = job.LoadJobConfig( + write_disposition=job.WriteDisposition.WRITE_TRUNCATE, + source_format=job.SourceFormat.PARQUET, + ) + original_config_copy = copy.deepcopy(job_config) + + get_table_patch = mock.patch( + "google.cloud.bigquery.client.Client.get_table", + autospec=True, + return_value=mock.Mock( + schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")] + ), + ) + load_patch = mock.patch( + "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True + ) + with load_patch as load_table_from_file, get_table_patch as get_table: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION + ) + + # no need to fetch and inspect table schema for WRITE_TRUNCATE jobs + assert not get_table.called + + load_table_from_file.assert_called_once_with( + client, + mock.ANY, + self.TABLE_REF, + num_retries=_DEFAULT_NUM_RETRIES, + rewind=True, + job_id=mock.ANY, + job_id_prefix=None, + location=self.LOCATION, + project=None, + job_config=mock.ANY, + ) + + sent_config = 
load_table_from_file.mock_calls[0][2]["job_config"] + assert sent_config.source_format == job.SourceFormat.PARQUET + assert sent_config.write_disposition == job.WriteDisposition.WRITE_TRUNCATE + + # the original config object should not have been modified + assert job_config.to_api_repr() == original_config_copy.to_api_repr() + + @unittest.skipIf(pandas is None, "Requires `pandas`") + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") + def test_load_table_from_dataframe_w_custom_job_config_w_wrong_source_format(self): + from google.cloud.bigquery import job + + client = self._make_client() + records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}] + dataframe = pandas.DataFrame(records) + job_config = job.LoadJobConfig( + write_disposition=job.WriteDisposition.WRITE_TRUNCATE, + source_format=job.SourceFormat.ORC, + ) + + with pytest.raises(ValueError) as exc: + client.load_table_from_dataframe( + dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION + ) + + assert "Got unexpected source_format:" in str(exc.value) + @unittest.skipIf(pandas is None, "Requires `pandas`") @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_load_table_from_dataframe_w_automatic_schema(self):