Commit

Merge branch 'master' into iss-151
plamut committed Jul 20, 2020
2 parents f579d17 + 9c3409b commit 6e28998
Showing 7 changed files with 225 additions and 16 deletions.
12 changes: 12 additions & 0 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -19,6 +19,7 @@
import logging
import warnings

import six
from six.moves import queue

try:
@@ -780,3 +781,14 @@ def download_dataframe_bqstorage(
selected_fields=selected_fields,
page_to_item=page_to_item,
)


def dataframe_to_json_generator(dataframe):
for row in dataframe.itertuples(index=False, name=None):
output = {}
for column, value in six.moves.zip(dataframe.columns, row):
# Omit NaN values.
if value != value:
continue
output[column] = value
yield output
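
The helper relies on the fact that ``NaN`` is the only value that compares unequal to itself, so the ``value != value`` check drops ``NaN`` cells without needing a pandas or NumPy call. A minimal sketch of the resulting behavior (assuming pandas is installed; the column names are made up for illustration):

import pandas

from google.cloud.bigquery import _pandas_helpers

dataframe = pandas.DataFrame(
    {"name": ["alice", "bob"], "score": [1.5, float("NaN")]}
)

# The NaN cell in the second row is omitted from that row's JSON payload.
rows = list(_pandas_helpers.dataframe_to_json_generator(dataframe))
assert rows == [{"name": "alice", "score": 1.5}, {"name": "bob"}]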
21 changes: 16 additions & 5 deletions google/cloud/bigquery/client.py
@@ -441,6 +441,10 @@ def create_dataset(
google.cloud.bigquery.dataset.Dataset:
A new ``Dataset`` returned from the API.
Raises:
google.cloud.exceptions.Conflict:
If the dataset already exists.
Example:
>>> from google.cloud import bigquery
@@ -496,6 +500,10 @@ def create_routine(
Returns:
google.cloud.bigquery.routine.Routine:
A new ``Routine`` returned from the service.
Raises:
google.cloud.exceptions.Conflict:
If the routine already exists.
"""
reference = routine.reference
path = "/projects/{}/datasets/{}/routines".format(
@@ -540,6 +548,10 @@ def create_table(self, table, exists_ok=False, retry=DEFAULT_RETRY, timeout=None
Returns:
google.cloud.bigquery.table.Table:
A new ``Table`` returned from the service.
Raises:
google.cloud.exceptions.Conflict:
If the table already exists.
"""
table = _table_arg_to_table(table, default_project=self.project)

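The ``Raises`` additions above document the ``Conflict`` error that ``create_dataset``, ``create_routine``, and ``create_table`` surface when the resource already exists. A hedged sketch of handling it (the project and dataset IDs are hypothetical):

from google.cloud import bigquery
from google.cloud.exceptions import Conflict

client = bigquery.Client()

try:
    client.create_dataset("my_project.my_dataset")  # hypothetical dataset ID
except Conflict:
    # As documented above: the dataset already exists.
    pass

# create_dataset and create_table also accept exists_ok=True, which swallows
# the Conflict and returns the existing resource instead.
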
@@ -2535,7 +2547,9 @@ def insert_rows_from_dataframe(
]):
The destination table for the row data, or a reference to it.
dataframe (pandas.DataFrame):
A :class:`~pandas.DataFrame` containing the data to load.
A :class:`~pandas.DataFrame` containing the data to load. Any
``NaN`` values present in the dataframe are omitted from the
streaming API request(s).
selected_fields (Sequence[google.cloud.bigquery.schema.SchemaField]):
The fields to return. Required if ``table`` is a
:class:`~google.cloud.bigquery.table.TableReference`.
@@ -2559,10 +2573,7 @@ def insert_rows_from_dataframe(
insert_results = []

chunk_count = int(math.ceil(len(dataframe) / chunk_size))
rows_iter = (
dict(six.moves.zip(dataframe.columns, row))
for row in dataframe.itertuples(index=False, name=None)
)
rows_iter = _pandas_helpers.dataframe_to_json_generator(dataframe)

for _ in range(chunk_count):
rows_chunk = itertools.islice(rows_iter, chunk_size)
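With ``dataframe_to_json_generator`` in place, ``insert_rows_from_dataframe`` skips ``NaN`` cells instead of streaming them, as the updated docstring above notes. A hedged usage sketch (the table ID and column names are hypothetical):

import pandas

from google.cloud import bigquery

client = bigquery.Client()
table = client.get_table("my_project.my_dataset.my_table")  # hypothetical table ID

dataframe = pandas.DataFrame(
    {"string_col": ["alpha", "bravo"], "float_col": [0.5, float("NaN")]}
)

# Each chunk returns its own list of per-row errors; empty lists mean success.
# The NaN in the second row is omitted, so BigQuery stores a NULL there.
chunk_errors = client.insert_rows_from_dataframe(table, dataframe, chunk_size=500)
for errors in chunk_errors:
    assert not errors
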
20 changes: 20 additions & 0 deletions google/cloud/bigquery/table.py
@@ -1891,10 +1891,20 @@ def interval(self, value):
def _key(self):
return tuple(sorted(self._properties.items()))

def __eq__(self, other):
if not isinstance(other, PartitionRange):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __repr__(self):
key_vals = ["{}={}".format(key, val) for key, val in self._key()]
return "PartitionRange({})".format(", ".join(key_vals))

__hash__ = None


class RangePartitioning(object):
"""Range-based partitioning configuration for a table.
@@ -1961,10 +1971,20 @@ def field(self, value):
def _key(self):
return (("field", self.field), ("range_", self.range_))

def __eq__(self, other):
if not isinstance(other, RangePartitioning):
return NotImplemented
return self._key() == other._key()

def __ne__(self, other):
return not self == other

def __repr__(self):
key_vals = ["{}={}".format(key, repr(val)) for key, val in self._key()]
return "RangePartitioning({})".format(", ".join(key_vals))

__hash__ = None


class TimePartitioningType(object):
"""Specifies the type of time partitioning to perform."""
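The new ``__eq__``/``__ne__`` methods compare ranges by value via ``_key()``, and ``__hash__ = None`` explicitly leaves these mutable configuration objects unhashable. A short sketch of the resulting behavior (the partitioning field name is made up):

from google.cloud.bigquery.table import PartitionRange, RangePartitioning

a = PartitionRange(start=0, end=100, interval=10)
b = PartitionRange(start=0, end=100, interval=10)

assert a == b  # compared field by field via _key()
assert a != PartitionRange(start=0, end=100, interval=20)

assert RangePartitioning(range_=a, field="zone_id") == RangePartitioning(
    range_=b, field="zone_id"
)

try:
    hash(a)  # raises because __hash__ is set to None
except TypeError:
    pass
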
8 changes: 1 addition & 7 deletions setup.py
@@ -30,16 +30,10 @@
release_status = "Development Status :: 5 - Production/Stable"
dependencies = [
'enum34; python_version < "3.4"',
"google-auth >= 1.9.0, < 2.0dev",
"google-api-core >= 1.15.0, < 2.0dev",
"google-api-core >= 1.21.0, < 2.0dev",
"google-cloud-core >= 1.1.0, < 2.0dev",
"google-resumable-media >= 0.5.0, < 0.6dev",
"protobuf >= 3.6.0",
"six >=1.13.0,< 2.0.0dev",
# rsa >= 4.1 is not compatible with Python 2
# https://github.com/sybrenstuvel/python-rsa/issues/152
'rsa <4.1; python_version < "3"',
'rsa >=3.1.4, <5; python_version >= "3"',
]
extras = {
"bqstorage": [
30 changes: 26 additions & 4 deletions tests/system.py
@@ -2335,6 +2335,14 @@ def test_insert_rows_from_dataframe(self):
"string_col": "another string",
"int_col": 50,
},
{
"float_col": 6.66,
"bool_col": True,
# Include a NaN value, because pandas often uses NaN as a
# NULL value indicator.
"string_col": float("NaN"),
"int_col": 60,
},
]
)

@@ -2344,14 +2352,28 @@ def test_insert_rows_from_dataframe(self):
table = retry_403(Config.CLIENT.create_table)(table_arg)
self.to_delete.insert(0, table)

Config.CLIENT.insert_rows_from_dataframe(table, dataframe, chunk_size=3)
chunk_errors = Config.CLIENT.insert_rows_from_dataframe(
table, dataframe, chunk_size=3
)
for errors in chunk_errors:
assert not errors

retry = RetryResult(_has_rows, max_tries=8)
rows = retry(self._fetch_single_page)(table)
# Use query to fetch rows instead of listing directly from the table so
# that we get values from the streaming buffer.
rows = list(
Config.CLIENT.query(
"SELECT * FROM `{}.{}.{}`".format(
table.project, table.dataset_id, table.table_id
)
)
)

sorted_rows = sorted(rows, key=operator.attrgetter("int_col"))
row_tuples = [r.values() for r in sorted_rows]
expected = [tuple(data_row) for data_row in dataframe.itertuples(index=False)]
expected = [
tuple(None if col != col else col for col in data_row)
for data_row in dataframe.itertuples(index=False)
]

assert len(row_tuples) == len(expected)

68 changes: 68 additions & 0 deletions tests/unit/test_client.py
@@ -5582,6 +5582,74 @@ def test_insert_rows_from_dataframe(self):
)
assert call == expected_call

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_insert_rows_from_dataframe_nan(self):
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery.table import Table

API_PATH = "/projects/{}/datasets/{}/tables/{}/insertAll".format(
self.PROJECT, self.DS_ID, self.TABLE_REF.table_id
)

dataframe = pandas.DataFrame(
{
"str_col": ["abc", "def", float("NaN"), "jkl"],
"int_col": [1, float("NaN"), 3, 4],
"float_col": [float("NaN"), 0.25, 0.5, 0.125],
}
)

# create client
creds = _make_credentials()
http = object()
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
conn = client._connection = make_connection({}, {})

# create table
schema = [
SchemaField("str_col", "STRING"),
SchemaField("int_col", "INTEGER"),
SchemaField("float_col", "FLOAT"),
]
table = Table(self.TABLE_REF, schema=schema)

with mock.patch("uuid.uuid4", side_effect=map(str, range(len(dataframe)))):
error_info = client.insert_rows_from_dataframe(
table, dataframe, chunk_size=3, timeout=7.5
)

self.assertEqual(len(error_info), 2)
for chunk_errors in error_info:
assert chunk_errors == []

EXPECTED_SENT_DATA = [
{
"rows": [
{"insertId": "0", "json": {"str_col": "abc", "int_col": 1}},
{"insertId": "1", "json": {"str_col": "def", "float_col": 0.25}},
{"insertId": "2", "json": {"int_col": 3, "float_col": 0.5}},
]
},
{
"rows": [
{
"insertId": "3",
"json": {"str_col": "jkl", "int_col": 4, "float_col": 0.125},
}
]
},
]

actual_calls = conn.api_request.call_args_list

for call, expected_data in six.moves.zip_longest(
actual_calls, EXPECTED_SENT_DATA
):
expected_call = mock.call(
method="POST", path=API_PATH, data=expected_data, timeout=7.5
)
assert call == expected_call

@unittest.skipIf(pandas is None, "Requires `pandas`")
def test_insert_rows_from_dataframe_many_columns(self):
from google.cloud.bigquery.schema import SchemaField
82 changes: 82 additions & 0 deletions tests/unit/test_table.py
@@ -3525,6 +3525,37 @@ def test_constructor_w_resource(self):
assert object_under_test.end == 1234567890
assert object_under_test.interval == 1000000

def test___eq___start_mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=2, end=10, interval=2)
self.assertNotEqual(object_under_test, other)

def test___eq___end__mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=11, interval=2)
self.assertNotEqual(object_under_test, other)

def test___eq___interval__mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=11, interval=3)
self.assertNotEqual(object_under_test, other)

def test___eq___hit(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
other = self._make_one(start=1, end=10, interval=2)
self.assertEqual(object_under_test, other)

def test__eq___type_mismatch(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
self.assertNotEqual(object_under_test, object())
self.assertEqual(object_under_test, mock.ANY)

def test_unhashable_object(self):
object_under_test1 = self._make_one(start=1, end=10, interval=2)

with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"):
hash(object_under_test1)

def test_repr(self):
object_under_test = self._make_one(start=1, end=10, interval=2)
assert repr(object_under_test) == "PartitionRange(end=10, interval=2, start=1)"
@@ -3574,6 +3605,57 @@ def test_range_w_wrong_type(self):
with pytest.raises(ValueError, match="PartitionRange"):
object_under_test.range_ = object()

def test___eq___field_mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="float_col"
)
self.assertNotEqual(object_under_test, other)

def test___eq___range__mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=2, end=20, interval=2), field="float_col"
)
self.assertNotEqual(object_under_test, other)

def test___eq___hit(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
other = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
self.assertEqual(object_under_test, other)

def test__eq___type_mismatch(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
self.assertNotEqual(object_under_test, object())
self.assertEqual(object_under_test, mock.ANY)

def test_unhashable_object(self):
from google.cloud.bigquery.table import PartitionRange

object_under_test1 = self._make_one(
range_=PartitionRange(start=1, end=10, interval=2), field="integer_col"
)
with six.assertRaisesRegex(self, TypeError, r".*unhashable type.*"):
hash(object_under_test1)

def test_repr(self):
from google.cloud.bigquery.table import PartitionRange
