feat: promote RowIterator.to_arrow_iterable to public method (#1073)
* feat: promote `to_arrow_iterable` to public method

* use correct version number

* Update google/cloud/bigquery/table.py

Co-authored-by: Tim Swast <swast@google.com>
judahrand and tswast committed Nov 19, 2021
1 parent 1b5dc5c commit 21cd710
Showing 3 changed files with 297 additions and 4 deletions.
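For orientation before the diff, here is a minimal usage sketch of the newly public method; the project, dataset, and query are illustrative and not taken from the change itself.

```python
from google.cloud import bigquery

client = bigquery.Client()
query_job = client.query("SELECT name, age FROM `my-project.my_dataset.my_table`")
rows = query_job.result()  # RowIterator

# Stream the results as pyarrow.RecordBatch objects instead of
# materializing the whole result set in memory at once.
for record_batch in rows.to_arrow_iterable():
    print(record_batch.num_rows)
```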
8 changes: 7 additions & 1 deletion google/cloud/bigquery/_pandas_helpers.py
@@ -838,7 +838,12 @@ def _download_table_bqstorage(


def download_arrow_bqstorage(
project_id, table, bqstorage_client, preserve_order=False, selected_fields=None,
project_id,
table,
bqstorage_client,
preserve_order=False,
selected_fields=None,
max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
):
return _download_table_bqstorage(
project_id,
@@ -847,6 +852,7 @@ def download_arrow_bqstorage(
preserve_order=preserve_order,
selected_fields=selected_fields,
page_to_item=_bqstorage_page_to_arrow,
max_queue_size=max_queue_size,
)
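The new keyword is simply forwarded to `_download_table_bqstorage`, where it bounds the internal page queue. A rough sketch of that back-pressure idea, not the library's actual implementation:

```python
import queue
import threading

def _produce(pages, page_queue):
    # put() blocks once the queue is full, so the download thread cannot
    # run arbitrarily far ahead of the consumer.
    for page in pages:
        page_queue.put(page)
    page_queue.put(None)  # sentinel: no more pages

def stream_pages(pages, max_queue_size=10):
    page_queue = queue.Queue(maxsize=max_queue_size)
    threading.Thread(target=_produce, args=(pages, page_queue), daemon=True).start()
    while True:
        page = page_queue.get()
        if page is None:
            return
        yield page
```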


75 changes: 72 additions & 3 deletions google/cloud/bigquery/table.py
@@ -1629,15 +1629,57 @@ def _to_page_iterable(
)
yield from result_pages

def _to_arrow_iterable(self, bqstorage_client=None):
"""Create an iterable of arrow RecordBatches, to process the table as a stream."""
def to_arrow_iterable(
self,
bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
max_queue_size: int = _pandas_helpers._MAX_QUEUE_SIZE_DEFAULT, # type: ignore
) -> Iterator["pyarrow.RecordBatch"]:
"""[Beta] Create an iterable of class:`pyarrow.RecordBatch`, to process the table as a stream.
Args:
bqstorage_client (Optional[google.cloud.bigquery_storage_v1.BigQueryReadClient]):
A BigQuery Storage API client. If supplied, use the faster
BigQuery Storage API to fetch rows from BigQuery.
This method requires the ``pyarrow`` and
``google-cloud-bigquery-storage`` libraries.
This method only exposes a subset of the capabilities of the
BigQuery Storage API. For full access to all features
(projections, filters, snapshots) use the Storage API directly.
max_queue_size (Optional[int]):
The maximum number of result pages to hold in the internal queue when
streaming query results over the BigQuery Storage API. Ignored if
Storage API is not used.
By default, the max queue size is set to the number of BQ Storage streams
created by the server. If ``max_queue_size`` is :data:`None`, the queue
size is infinite.
Returns:
pyarrow.RecordBatch:
A generator of :class:`~pyarrow.RecordBatch`.
Raises:
ValueError:
If the :mod:`pyarrow` library cannot be imported.
.. versionadded:: 2.31.0
"""
if pyarrow is None:
raise ValueError(_NO_PYARROW_ERROR)

self._maybe_warn_max_results(bqstorage_client)

bqstorage_download = functools.partial(
_pandas_helpers.download_arrow_bqstorage,
self._project,
self._table,
bqstorage_client,
preserve_order=self._preserve_order,
selected_fields=self._selected_fields,
max_queue_size=max_queue_size,
)
tabledata_list_download = functools.partial(
_pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
@@ -1729,7 +1771,7 @@ def to_arrow(
)

record_batches = []
for record_batch in self._to_arrow_iterable(
for record_batch in self.to_arrow_iterable(
bqstorage_client=bqstorage_client
):
record_batches.append(record_batch)
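With this change, `to_arrow` is a thin consumer of the public iterable. A simplified sketch of the relationship (progress-bar handling and the zero-batch fallback of the real method are omitted; `rows` is a `RowIterator` as in the earlier example):

```python
import pyarrow

# Roughly what RowIterator.to_arrow() now does internally: collect the
# streamed batches and stitch them into a single Table.
record_batches = list(rows.to_arrow_iterable())
arrow_table = pyarrow.Table.from_batches(record_batches)
```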
@@ -2225,6 +2267,33 @@ def to_dataframe_iterable(
raise ValueError(_NO_PANDAS_ERROR)
return iter((pandas.DataFrame(),))

def to_arrow_iterable(
self,
bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
max_queue_size: Optional[int] = None,
) -> Iterator["pyarrow.RecordBatch"]:
"""Create an iterable of pandas DataFrames, to process the table as a stream.
.. versionadded:: 2.31.0
Args:
bqstorage_client:
Ignored. Added for compatibility with RowIterator.
max_queue_size:
Ignored. Added for compatibility with RowIterator.
Returns:
An iterator yielding a single empty :class:`~pyarrow.RecordBatch`.
Raises:
ValueError:
If the :mod:`pyarrow` library cannot be imported.
"""
if pyarrow is None:
raise ValueError(_NO_PYARROW_ERROR)
return iter((pyarrow.record_batch([]),))

def __iter__(self):
return iter(())
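Taken together, the table.py changes let callers combine the Storage API client with a bounded queue. A hedged example; the clients, table name, and `handle` consumer are purely illustrative:

```python
from google.cloud import bigquery, bigquery_storage

client = bigquery.Client()
bqstorage_client = bigquery_storage.BigQueryReadClient()

rows = client.list_rows("my-project.my_dataset.my_table")

# Hold at most 10 result pages in the internal queue; pass None for an
# unbounded queue. Without a Storage client, the REST API path is used.
for batch in rows.to_arrow_iterable(bqstorage_client=bqstorage_client, max_queue_size=10):
    handle(batch)  # `handle` is a hypothetical downstream consumer
```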

218 changes: 218 additions & 0 deletions tests/unit/test_table.py
@@ -1840,6 +1840,25 @@ def test_to_arrow(self):
self.assertIsInstance(tbl, pyarrow.Table)
self.assertEqual(tbl.num_rows, 0)

@mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
row_iterator = self._make_one()
with self.assertRaises(ValueError):
row_iterator.to_arrow_iterable()

@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_to_arrow_iterable(self):
row_iterator = self._make_one()
arrow_iter = row_iterator.to_arrow_iterable()

result = list(arrow_iter)

self.assertEqual(len(result), 1)
record_batch = result[0]
self.assertIsInstance(record_batch, pyarrow.RecordBatch)
self.assertEqual(record_batch.num_rows, 0)
self.assertEqual(record_batch.num_columns, 0)

@mock.patch("google.cloud.bigquery.table.pandas", new=None)
def test_to_dataframe_error_if_pandas_is_none(self):
row_iterator = self._make_one()
@@ -2151,6 +2170,205 @@ def test__validate_bqstorage_returns_false_w_warning_if_obsolete_version(self):
]
assert matching_warnings, "Obsolete dependency warning not raised."

@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_to_arrow_iterable(self):
from google.cloud.bigquery.schema import SchemaField

schema = [
SchemaField("name", "STRING", mode="REQUIRED"),
SchemaField("age", "INTEGER", mode="REQUIRED"),
SchemaField(
"child",
"RECORD",
mode="REPEATED",
fields=[
SchemaField("name", "STRING", mode="REQUIRED"),
SchemaField("age", "INTEGER", mode="REQUIRED"),
],
),
]
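# The row payloads below use the tabledata.list REST encoding: "f" holds the
# list of field values and each "v" wraps a single value (nested for RECORD fields).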
rows = [
{
"f": [
{"v": "Bharney Rhubble"},
{"v": "33"},
{
"v": [
{"v": {"f": [{"v": "Whamm-Whamm Rhubble"}, {"v": "3"}]}},
{"v": {"f": [{"v": "Hoppy"}, {"v": "1"}]}},
]
},
]
},
{
"f": [
{"v": "Wylma Phlyntstone"},
{"v": "29"},
{
"v": [
{"v": {"f": [{"v": "Bepples Phlyntstone"}, {"v": "0"}]}},
{"v": {"f": [{"v": "Dino"}, {"v": "4"}]}},
]
},
]
},
]
path = "/foo"
api_request = mock.Mock(
side_effect=[
{"rows": [rows[0]], "pageToken": "NEXTPAGE"},
{"rows": [rows[1]]},
]
)
row_iterator = self._make_one(
_mock_client(), api_request, path, schema, page_size=1, max_results=5
)

record_batches = row_iterator.to_arrow_iterable()
self.assertIsInstance(record_batches, types.GeneratorType)
record_batches = list(record_batches)
self.assertEqual(len(record_batches), 2)

# Check the schema.
for record_batch in record_batches:
self.assertIsInstance(record_batch, pyarrow.RecordBatch)
self.assertEqual(record_batch.schema[0].name, "name")
self.assertTrue(pyarrow.types.is_string(record_batch.schema[0].type))
self.assertEqual(record_batch.schema[1].name, "age")
self.assertTrue(pyarrow.types.is_int64(record_batch.schema[1].type))
child_field = record_batch.schema[2]
self.assertEqual(child_field.name, "child")
self.assertTrue(pyarrow.types.is_list(child_field.type))
self.assertTrue(pyarrow.types.is_struct(child_field.type.value_type))
self.assertEqual(child_field.type.value_type[0].name, "name")
self.assertEqual(child_field.type.value_type[1].name, "age")

# Check the data.
record_batch_1 = record_batches[0].to_pydict()
names = record_batch_1["name"]
ages = record_batch_1["age"]
children = record_batch_1["child"]
self.assertEqual(names, ["Bharney Rhubble"])
self.assertEqual(ages, [33])
self.assertEqual(
children,
[
[
{"name": "Whamm-Whamm Rhubble", "age": 3},
{"name": "Hoppy", "age": 1},
],
],
)

record_batch_2 = record_batches[1].to_pydict()
names = record_batch_2["name"]
ages = record_batch_2["age"]
children = record_batch_2["child"]
self.assertEqual(names, ["Wylma Phlyntstone"])
self.assertEqual(ages, [29])
self.assertEqual(
children,
[[{"name": "Bepples Phlyntstone", "age": 0}, {"name": "Dino", "age": 4}]],
)

@mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
from google.cloud.bigquery.schema import SchemaField

schema = [
SchemaField("name", "STRING", mode="REQUIRED"),
SchemaField("age", "INTEGER", mode="REQUIRED"),
]
rows = [
{"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
{"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
]
path = "/foo"
api_request = mock.Mock(return_value={"rows": rows})
row_iterator = self._make_one(_mock_client(), api_request, path, schema)

with pytest.raises(ValueError, match="pyarrow"):
row_iterator.to_arrow_iterable()

@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
@unittest.skipIf(
bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
)
def test_to_arrow_iterable_w_bqstorage(self):
from google.cloud.bigquery import schema
from google.cloud.bigquery import table as mut
from google.cloud.bigquery_storage_v1 import reader

bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient)
bqstorage_client._transport = mock.create_autospec(
big_query_read_grpc_transport.BigQueryReadGrpcTransport
)
streams = [
# Use two streams so we can check that frames are read from each stream.
{"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
{"name": "/projects/proj/dataset/dset/tables/tbl/streams/5678"},
]
session = bigquery_storage.types.ReadSession(streams=streams)
arrow_schema = pyarrow.schema(
[
pyarrow.field("colA", pyarrow.int64()),
# Not alphabetical to test column order.
pyarrow.field("colC", pyarrow.float64()),
pyarrow.field("colB", pyarrow.string()),
]
)
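# The mocked read session advertises its schema as a serialized Arrow IPC
# schema blob, mirroring what the BQ Storage API returns.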
session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes()
bqstorage_client.create_read_session.return_value = session

mock_rowstream = mock.create_autospec(reader.ReadRowsStream)
bqstorage_client.read_rows.return_value = mock_rowstream

mock_rows = mock.create_autospec(reader.ReadRowsIterable)
mock_rowstream.rows.return_value = mock_rows
page_items = [
pyarrow.array([1, -1]),
pyarrow.array([2.0, 4.0]),
pyarrow.array(["abc", "def"]),
]

expected_record_batch = pyarrow.RecordBatch.from_arrays(
page_items, schema=arrow_schema
)
expected_num_record_batches = 3

mock_page = mock.create_autospec(reader.ReadRowsPage)
mock_page.to_arrow.return_value = expected_record_batch
mock_pages = (mock_page,) * expected_num_record_batches
type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages)
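# Every mocked stream serves the same pages, so the iterable should surface
# len(streams) * len(mock_pages) record batches in total.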

schema = [
schema.SchemaField("colA", "INTEGER"),
schema.SchemaField("colC", "FLOAT"),
schema.SchemaField("colB", "STRING"),
]

row_iterator = mut.RowIterator(
_mock_client(),
None, # api_request: ignored
None, # path: ignored
schema,
table=mut.TableReference.from_string("proj.dset.tbl"),
selected_fields=schema,
)

record_batches = list(
row_iterator.to_arrow_iterable(bqstorage_client=bqstorage_client)
)
total_record_batches = len(streams) * len(mock_pages)
self.assertEqual(len(record_batches), total_record_batches)

for record_batch in record_batches:
# Are the record batches returned as expected?
self.assertEqual(record_batch, expected_record_batch)

# Don't close the client if it was passed in.
bqstorage_client._transport.grpc_channel.close.assert_not_called()

@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
def test_to_arrow(self):
from google.cloud.bigquery.schema import SchemaField
