diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index bced246e8..ff6525399 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -19,6 +19,7 @@
 import logging
 import warnings
 
+import six
 from six.moves import queue
 
 try:
@@ -780,3 +781,14 @@ def download_dataframe_bqstorage(
         selected_fields=selected_fields,
         page_to_item=page_to_item,
     )
+
+
+def dataframe_to_json_generator(dataframe):
+    for row in dataframe.itertuples(index=False, name=None):
+        output = {}
+        for column, value in six.moves.zip(dataframe.columns, row):
+            # Omit NaN values.
+            if value != value:
+                continue
+            output[column] = value
+        yield output
diff --git a/google/cloud/bigquery/client.py b/google/cloud/bigquery/client.py
index eceedcd67..20a485698 100644
--- a/google/cloud/bigquery/client.py
+++ b/google/cloud/bigquery/client.py
@@ -2535,7 +2535,9 @@ def insert_rows_from_dataframe(
             ]):
                 The destination table for the row data, or a reference to it.
             dataframe (pandas.DataFrame):
-                A :class:`~pandas.DataFrame` containing the data to load.
+                A :class:`~pandas.DataFrame` containing the data to load. Any
+                ``NaN`` values present in the dataframe are omitted from the
+                streaming API request(s).
             selected_fields (Sequence[google.cloud.bigquery.schema.SchemaField]):
                 The fields to return. Required if ``table`` is a
                 :class:`~google.cloud.bigquery.table.TableReference`.
@@ -2559,10 +2561,7 @@ def insert_rows_from_dataframe(
         insert_results = []
 
         chunk_count = int(math.ceil(len(dataframe) / chunk_size))
-        rows_iter = (
-            dict(six.moves.zip(dataframe.columns, row))
-            for row in dataframe.itertuples(index=False, name=None)
-        )
+        rows_iter = _pandas_helpers.dataframe_to_json_generator(dataframe)
 
         for _ in range(chunk_count):
             rows_chunk = itertools.islice(rows_iter, chunk_size)
diff --git a/tests/system.py b/tests/system.py
index 14d3f49a1..cd5454a87 100644
--- a/tests/system.py
+++ b/tests/system.py
@@ -2335,6 +2335,14 @@ def test_insert_rows_from_dataframe(self):
                     "string_col": "another string",
                     "int_col": 50,
                 },
+                {
+                    "float_col": 6.66,
+                    "bool_col": True,
+                    # Include a NaN value, because pandas often uses NaN as a
+                    # NULL value indicator.
+                    "string_col": float("NaN"),
+                    "int_col": 60,
+                },
             ]
         )
 
@@ -2344,14 +2352,28 @@ def test_insert_rows_from_dataframe(self):
         table = retry_403(Config.CLIENT.create_table)(table_arg)
         self.to_delete.insert(0, table)
 
-        Config.CLIENT.insert_rows_from_dataframe(table, dataframe, chunk_size=3)
+        chunk_errors = Config.CLIENT.insert_rows_from_dataframe(
+            table, dataframe, chunk_size=3
+        )
+        for errors in chunk_errors:
+            assert not errors
 
-        retry = RetryResult(_has_rows, max_tries=8)
-        rows = retry(self._fetch_single_page)(table)
+        # Use query to fetch rows instead of listing directly from the table so
+        # that we get values from the streaming buffer.
+        rows = list(
+            Config.CLIENT.query(
+                "SELECT * FROM `{}.{}.{}`".format(
+                    table.project, table.dataset_id, table.table_id
+                )
+            )
+        )
 
         sorted_rows = sorted(rows, key=operator.attrgetter("int_col"))
         row_tuples = [r.values() for r in sorted_rows]
-        expected = [tuple(data_row) for data_row in dataframe.itertuples(index=False)]
+        expected = [
+            tuple(None if col != col else col for col in data_row)
+            for data_row in dataframe.itertuples(index=False)
+        ]
 
         assert len(row_tuples) == len(expected)
 
diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py
index 0e083d43f..2c4c1342c 100644
--- a/tests/unit/test_client.py
+++ b/tests/unit/test_client.py
@@ -5582,6 +5582,74 @@ def test_insert_rows_from_dataframe(self):
             )
             assert call == expected_call
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    def test_insert_rows_from_dataframe_nan(self):
+        from google.cloud.bigquery.schema import SchemaField
+        from google.cloud.bigquery.table import Table
+
+        API_PATH = "/projects/{}/datasets/{}/tables/{}/insertAll".format(
+            self.PROJECT, self.DS_ID, self.TABLE_REF.table_id
+        )
+
+        dataframe = pandas.DataFrame(
+            {
+                "str_col": ["abc", "def", float("NaN"), "jkl"],
+                "int_col": [1, float("NaN"), 3, 4],
+                "float_col": [float("NaN"), 0.25, 0.5, 0.125],
+            }
+        )
+
+        # create client
+        creds = _make_credentials()
+        http = object()
+        client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
+        conn = client._connection = make_connection({}, {})
+
+        # create table
+        schema = [
+            SchemaField("str_col", "STRING"),
+            SchemaField("int_col", "INTEGER"),
+            SchemaField("float_col", "FLOAT"),
+        ]
+        table = Table(self.TABLE_REF, schema=schema)
+
+        with mock.patch("uuid.uuid4", side_effect=map(str, range(len(dataframe)))):
+            error_info = client.insert_rows_from_dataframe(
+                table, dataframe, chunk_size=3, timeout=7.5
+            )
+
+        self.assertEqual(len(error_info), 2)
+        for chunk_errors in error_info:
+            assert chunk_errors == []
+
+        EXPECTED_SENT_DATA = [
+            {
+                "rows": [
+                    {"insertId": "0", "json": {"str_col": "abc", "int_col": 1}},
+                    {"insertId": "1", "json": {"str_col": "def", "float_col": 0.25}},
+                    {"insertId": "2", "json": {"int_col": 3, "float_col": 0.5}},
+                ]
+            },
+            {
+                "rows": [
+                    {
+                        "insertId": "3",
+                        "json": {"str_col": "jkl", "int_col": 4, "float_col": 0.125},
+                    }
+                ]
+            },
+        ]
+
+        actual_calls = conn.api_request.call_args_list
+
+        for call, expected_data in six.moves.zip_longest(
+            actual_calls, EXPECTED_SENT_DATA
+        ):
+            expected_call = mock.call(
+                method="POST", path=API_PATH, data=expected_data, timeout=7.5
+            )
+            assert call == expected_call
+
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     def test_insert_rows_from_dataframe_many_columns(self):
         from google.cloud.bigquery.schema import SchemaField
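
For quick reference, here is a minimal sketch (not part of the patch) of the behavior the new `dataframe_to_json_generator` helper gives `insert_rows_from_dataframe`. The `value != value` check in the helper relies on the IEEE 754 rule that `NaN` compares unequal to itself, so no numpy import is needed:

```python
import pandas

from google.cloud.bigquery import _pandas_helpers

dataframe = pandas.DataFrame(
    {"str_col": ["abc", float("NaN")], "float_col": [0.25, 0.5]}
)

# NaN cells are dropped from each row's JSON payload instead of being
# serialized and rejected by the insertAll API.
rows = list(_pandas_helpers.dataframe_to_json_generator(dataframe))
# -> [{'str_col': 'abc', 'float_col': 0.25}, {'float_col': 0.5}]
```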