fix: omit NaN values when uploading from insert_rows_from_dataframe #170

Merged
12 changes: 12 additions & 0 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -19,6 +19,7 @@
import logging
import warnings

import six
from six.moves import queue

try:
@@ -780,3 +781,14 @@ def download_dataframe_bqstorage(
selected_fields=selected_fields,
page_to_item=page_to_item,
)


def dataframe_to_json_generator(dataframe):
    for row in dataframe.itertuples(index=False, name=None):
        output = {}
        for column, value in six.moves.zip(dataframe.columns, row):
            # Omit NaN values. NaN is the only value that compares unequal
            # to itself, so this check matches NaN cells of any column.
            if value != value:
                continue
            output[column] = value
        yield output
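
For reference, a minimal sketch (not part of the diff) of how the new helper drops NaN cells; it assumes pandas is installed and relies on NaN being the only value that compares unequal to itself:

import pandas

from google.cloud.bigquery import _pandas_helpers

# Row 0 has NaN in column "a", row 1 has NaN in column "b"; each NaN key is
# simply omitted from that row's dict.
df = pandas.DataFrame({"a": [float("NaN"), 1.0], "b": ["x", float("NaN")]})
print(list(_pandas_helpers.dataframe_to_json_generator(df)))
# [{'b': 'x'}, {'a': 1.0}]
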
9 changes: 4 additions & 5 deletions google/cloud/bigquery/client.py
@@ -2535,7 +2535,9 @@ def insert_rows_from_dataframe(
]):
The destination table for the row data, or a reference to it.
dataframe (pandas.DataFrame):
A :class:`~pandas.DataFrame` containing the data to load.
A :class:`~pandas.DataFrame` containing the data to load. Any
``NaN`` values present in the dataframe are omitted from the
streaming API request(s).
selected_fields (Sequence[google.cloud.bigquery.schema.SchemaField]):
The fields to return. Required if ``table`` is a
:class:`~google.cloud.bigquery.table.TableReference`.
@@ -2559,10 +2561,7 @@
insert_results = []

chunk_count = int(math.ceil(len(dataframe) / chunk_size))
rows_iter = (
dict(six.moves.zip(dataframe.columns, row))
for row in dataframe.itertuples(index=False, name=None)
)
rows_iter = _pandas_helpers.dataframe_to_json_generator(dataframe)

for _ in range(chunk_count):
rows_chunk = itertools.islice(rows_iter, chunk_size)
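
A minimal usage sketch of the behavior documented above (the table ID and column names are placeholders, and the table is assumed to already exist with a matching schema):

import pandas

from google.cloud import bigquery

client = bigquery.Client()
table = client.get_table("my-project.my_dataset.my_table")  # placeholder table ID

# The NaN in the second row's string_col is dropped from that row's
# streaming insert payload instead of being sent as a value.
dataframe = pandas.DataFrame({"string_col": ["abc", float("NaN")], "int_col": [10, 20]})
errors = client.insert_rows_from_dataframe(table, dataframe)
assert all(not chunk for chunk in errors)
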
30 changes: 26 additions & 4 deletions tests/system.py
@@ -2335,6 +2335,14 @@ def test_insert_rows_from_dataframe(self):
"string_col": "another string",
"int_col": 50,
},
{
"float_col": 6.66,
"bool_col": True,
# Include a NaN value, because pandas often uses NaN as a
# NULL value indicator.
"string_col": float("NaN"),
"int_col": 60,
},
]
)

@@ -2344,14 +2352,28 @@
table = retry_403(Config.CLIENT.create_table)(table_arg)
self.to_delete.insert(0, table)

Config.CLIENT.insert_rows_from_dataframe(table, dataframe, chunk_size=3)
chunk_errors = Config.CLIENT.insert_rows_from_dataframe(
table, dataframe, chunk_size=3
)
for errors in chunk_errors:
assert not errors

retry = RetryResult(_has_rows, max_tries=8)
rows = retry(self._fetch_single_page)(table)
# Use query to fetch rows instead of listing directly from the table so
# that we get values from the streaming buffer.
rows = list(
Config.CLIENT.query(
"SELECT * FROM `{}.{}.{}`".format(
table.project, table.dataset_id, table.table_id
)
)
)

sorted_rows = sorted(rows, key=operator.attrgetter("int_col"))
row_tuples = [r.values() for r in sorted_rows]
expected = [tuple(data_row) for data_row in dataframe.itertuples(index=False)]
expected = [
tuple(None if col != col else col for col in data_row)
for data_row in dataframe.itertuples(index=False)
]

assert len(row_tuples) == len(expected)

68 changes: 68 additions & 0 deletions tests/unit/test_client.py
@@ -5582,6 +5582,74 @@ def test_insert_rows_from_dataframe(self):
)
assert call == expected_call

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_insert_rows_from_dataframe_nan(self):
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery.table import Table

API_PATH = "/projects/{}/datasets/{}/tables/{}/insertAll".format(
self.PROJECT, self.DS_ID, self.TABLE_REF.table_id
)

dataframe = pandas.DataFrame(
{
"str_col": ["abc", "def", float("NaN"), "jkl"],
"int_col": [1, float("NaN"), 3, 4],
"float_col": [float("NaN"), 0.25, 0.5, 0.125],
}
)

# create client
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
conn = client._connection = make_connection({}, {})

# create table
schema = [
SchemaField("str_col", "STRING"),
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
]
table = Table(self.TABLE_REF, schema=schema)

with mock.patch("uuid.uuid4", side_effect=map(str, range(len(dataframe)))):
error_info = client.insert_rows_from_dataframe(
table, dataframe, chunk_size=3, timeout=7.5
)

self.assertEqual(len(error_info), 2)
for chunk_errors in error_info:
assert chunk_errors == []

EXPECTED_SENT_DATA = [
{
"rows": [
{"insertId": "0", "json": {"str_col": "abc", "int_col": 1}},
{"insertId": "1", "json": {"str_col": "def", "float_col": 0.25}},
{"insertId": "2", "json": {"int_col": 3, "float_col": 0.5}},
]
},
{
"rows": [
{
"insertId": "3",
"json": {"str_col": "jkl", "int_col": 4, "float_col": 0.125},
}
]
},
]

actual_calls = conn.api_request.call_args_list

for call, expected_data in six.moves.zip_longest(
actual_calls, EXPECTED_SENT_DATA
):
Review comment from a contributor on lines +5645 to +5647: Nice pattern!

expected_call = mock.call(
method="POST", path=API_PATH, data=expected_data, timeout=7.5
)
assert call == expected_call

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_insert_rows_from_dataframe_many_columns(self):
from google.cloud.bigquery.schema import SchemaField