diff --git a/google/cloud/bigquery/_pandas_helpers.py b/google/cloud/bigquery/_pandas_helpers.py
index de6356c2a..263a1a9cf 100644
--- a/google/cloud/bigquery/_pandas_helpers.py
+++ b/google/cloud/bigquery/_pandas_helpers.py
@@ -838,7 +838,12 @@ def _download_table_bqstorage(
 
 
 def download_arrow_bqstorage(
-    project_id, table, bqstorage_client, preserve_order=False, selected_fields=None,
+    project_id,
+    table,
+    bqstorage_client,
+    preserve_order=False,
+    selected_fields=None,
+    max_queue_size=_MAX_QUEUE_SIZE_DEFAULT,
 ):
     return _download_table_bqstorage(
         project_id,
@@ -847,6 +852,7 @@ def download_arrow_bqstorage(
         preserve_order=preserve_order,
         selected_fields=selected_fields,
         page_to_item=_bqstorage_page_to_arrow,
+        max_queue_size=max_queue_size,
     )
 
 
diff --git a/google/cloud/bigquery/table.py b/google/cloud/bigquery/table.py
index 60c8593c7..a0696f83f 100644
--- a/google/cloud/bigquery/table.py
+++ b/google/cloud/bigquery/table.py
@@ -1629,8 +1629,49 @@ def _to_page_iterable(
         )
         yield from result_pages
 
-    def _to_arrow_iterable(self, bqstorage_client=None):
-        """Create an iterable of arrow RecordBatches, to process the table as a stream."""
+    def to_arrow_iterable(
+        self,
+        bqstorage_client: "bigquery_storage.BigQueryReadClient" = None,
+        max_queue_size: int = _pandas_helpers._MAX_QUEUE_SIZE_DEFAULT,  # type: ignore
+    ) -> Iterator["pyarrow.RecordBatch"]:
+        """[Beta] Create an iterable of :class:`pyarrow.RecordBatch`, to process the table as a stream.
+
+        Args:
+            bqstorage_client (Optional[google.cloud.bigquery_storage_v1.BigQueryReadClient]):
+                A BigQuery Storage API client. If supplied, use the faster
+                BigQuery Storage API to fetch rows from BigQuery.
+
+                This method requires the ``pyarrow`` and
+                ``google-cloud-bigquery-storage`` libraries.
+
+                This method only exposes a subset of the capabilities of the
+                BigQuery Storage API. For full access to all features
+                (projections, filters, snapshots) use the Storage API directly.
+
+            max_queue_size (Optional[int]):
+                The maximum number of result pages to hold in the internal queue when
+                streaming query results over the BigQuery Storage API. Ignored if
+                the Storage API is not used.
+
+                By default, the max queue size is set to the number of BQ Storage streams
+                created by the server. If ``max_queue_size`` is :data:`None`, the queue
+                size is infinite.
+
+        Returns:
+            pyarrow.RecordBatch:
+                A generator of :class:`~pyarrow.RecordBatch`.
+
+        Raises:
+            ValueError:
+                If the :mod:`pyarrow` library cannot be imported.
+
+        .. versionadded:: 2.31.0
+        """
+        if pyarrow is None:
+            raise ValueError(_NO_PYARROW_ERROR)
+
+        self._maybe_warn_max_results(bqstorage_client)
+
         bqstorage_download = functools.partial(
             _pandas_helpers.download_arrow_bqstorage,
             self._project,
@@ -1638,6 +1679,7 @@ def _to_arrow_iterable(self, bqstorage_client=None):
             bqstorage_client,
             preserve_order=self._preserve_order,
             selected_fields=self._selected_fields,
+            max_queue_size=max_queue_size,
         )
         tabledata_list_download = functools.partial(
             _pandas_helpers.download_arrow_row_iterator, iter(self.pages), self.schema
@@ -1729,7 +1771,7 @@ def to_arrow(
             )
 
         record_batches = []
-        for record_batch in self._to_arrow_iterable(
+        for record_batch in self.to_arrow_iterable(
             bqstorage_client=bqstorage_client
         ):
             record_batches.append(record_batch)
@@ -2225,6 +2267,33 @@ def to_dataframe_iterable(
             raise ValueError(_NO_PANDAS_ERROR)
         return iter((pandas.DataFrame(),))
 
+    def to_arrow_iterable(
+        self,
+        bqstorage_client: Optional["bigquery_storage.BigQueryReadClient"] = None,
+        max_queue_size: Optional[int] = None,
+    ) -> Iterator["pyarrow.RecordBatch"]:
+        """Create an iterable of :class:`pyarrow.RecordBatch`, to process the table as a stream.
+
+        .. versionadded:: 2.31.0
+
+        Args:
+            bqstorage_client:
+                Ignored. Added for compatibility with RowIterator.
+
+            max_queue_size:
+                Ignored. Added for compatibility with RowIterator.
+
+        Returns:
+            An iterator yielding a single empty :class:`~pyarrow.RecordBatch`.
+
+        Raises:
+            ValueError:
+                If the :mod:`pyarrow` library cannot be imported.
+        """
+        if pyarrow is None:
+            raise ValueError(_NO_PYARROW_ERROR)
+        return iter((pyarrow.record_batch([]),))
+
     def __iter__(self):
         return iter(())
 
diff --git a/tests/unit/test_table.py b/tests/unit/test_table.py
index 3c68e3c5e..4f45eac3d 100644
--- a/tests/unit/test_table.py
+++ b/tests/unit/test_table.py
@@ -1840,6 +1840,25 @@ def test_to_arrow(self):
         self.assertIsInstance(tbl, pyarrow.Table)
         self.assertEqual(tbl.num_rows, 0)
 
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        row_iterator = self._make_one()
+        with self.assertRaises(ValueError):
+            row_iterator.to_arrow_iterable()
+
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        row_iterator = self._make_one()
+        arrow_iter = row_iterator.to_arrow_iterable()
+
+        result = list(arrow_iter)
+
+        self.assertEqual(len(result), 1)
+        record_batch = result[0]
+        self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+        self.assertEqual(record_batch.num_rows, 0)
+        self.assertEqual(record_batch.num_columns, 0)
+
     @mock.patch("google.cloud.bigquery.table.pandas", new=None)
     def test_to_dataframe_error_if_pandas_is_none(self):
         row_iterator = self._make_one()
@@ -2151,6 +2170,205 @@ def test__validate_bqstorage_returns_false_w_warning_if_obsolete_version(self):
         ]
         assert matching_warnings, "Obsolete dependency warning not raised."
 
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_to_arrow_iterable(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+            SchemaField(
+                "child",
+                "RECORD",
+                mode="REPEATED",
+                fields=[
+                    SchemaField("name", "STRING", mode="REQUIRED"),
+                    SchemaField("age", "INTEGER", mode="REQUIRED"),
+                ],
+            ),
+        ]
+        rows = [
+            {
+                "f": [
+                    {"v": "Bharney Rhubble"},
+                    {"v": "33"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Whamm-Whamm Rhubble"}, {"v": "3"}]}},
+                            {"v": {"f": [{"v": "Hoppy"}, {"v": "1"}]}},
+                        ]
+                    },
+                ]
+            },
+            {
+                "f": [
+                    {"v": "Wylma Phlyntstone"},
+                    {"v": "29"},
+                    {
+                        "v": [
+                            {"v": {"f": [{"v": "Bepples Phlyntstone"}, {"v": "0"}]}},
+                            {"v": {"f": [{"v": "Dino"}, {"v": "4"}]}},
+                        ]
+                    },
+                ]
+            },
+        ]
+        path = "/foo"
+        api_request = mock.Mock(
+            side_effect=[
+                {"rows": [rows[0]], "pageToken": "NEXTPAGE"},
+                {"rows": [rows[1]]},
+            ]
+        )
+        row_iterator = self._make_one(
+            _mock_client(), api_request, path, schema, page_size=1, max_results=5
+        )
+
+        record_batches = row_iterator.to_arrow_iterable()
+        self.assertIsInstance(record_batches, types.GeneratorType)
+        record_batches = list(record_batches)
+        self.assertEqual(len(record_batches), 2)
+
+        # Check the schema.
+        for record_batch in record_batches:
+            self.assertIsInstance(record_batch, pyarrow.RecordBatch)
+            self.assertEqual(record_batch.schema[0].name, "name")
+            self.assertTrue(pyarrow.types.is_string(record_batch.schema[0].type))
+            self.assertEqual(record_batch.schema[1].name, "age")
+            self.assertTrue(pyarrow.types.is_int64(record_batch.schema[1].type))
+            child_field = record_batch.schema[2]
+            self.assertEqual(child_field.name, "child")
+            self.assertTrue(pyarrow.types.is_list(child_field.type))
+            self.assertTrue(pyarrow.types.is_struct(child_field.type.value_type))
+            self.assertEqual(child_field.type.value_type[0].name, "name")
+            self.assertEqual(child_field.type.value_type[1].name, "age")
+
+        # Check the data.
+        record_batch_1 = record_batches[0].to_pydict()
+        names = record_batch_1["name"]
+        ages = record_batch_1["age"]
+        children = record_batch_1["child"]
+        self.assertEqual(names, ["Bharney Rhubble"])
+        self.assertEqual(ages, [33])
+        self.assertEqual(
+            children,
+            [
+                [
+                    {"name": "Whamm-Whamm Rhubble", "age": 3},
+                    {"name": "Hoppy", "age": 1},
+                ],
+            ],
+        )
+
+        record_batch_2 = record_batches[1].to_pydict()
+        names = record_batch_2["name"]
+        ages = record_batch_2["age"]
+        children = record_batch_2["child"]
+        self.assertEqual(names, ["Wylma Phlyntstone"])
+        self.assertEqual(ages, [29])
+        self.assertEqual(
+            children,
+            [[{"name": "Bepples Phlyntstone", "age": 0}, {"name": "Dino", "age": 4}]],
+        )
+
+    @mock.patch("google.cloud.bigquery.table.pyarrow", new=None)
+    def test_to_arrow_iterable_error_if_pyarrow_is_none(self):
+        from google.cloud.bigquery.schema import SchemaField
+
+        schema = [
+            SchemaField("name", "STRING", mode="REQUIRED"),
+            SchemaField("age", "INTEGER", mode="REQUIRED"),
+        ]
+        rows = [
+            {"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
+            {"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
+        ]
+        path = "/foo"
+        api_request = mock.Mock(return_value={"rows": rows})
+        row_iterator = self._make_one(_mock_client(), api_request, path, schema)
+
+        with pytest.raises(ValueError, match="pyarrow"):
+            row_iterator.to_arrow_iterable()
+
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    @unittest.skipIf(
+        bigquery_storage is None, "Requires `google-cloud-bigquery-storage`"
+    )
+    def test_to_arrow_iterable_w_bqstorage(self):
+        from google.cloud.bigquery import schema
+        from google.cloud.bigquery import table as mut
+        from google.cloud.bigquery_storage_v1 import reader
+
+        bqstorage_client = mock.create_autospec(bigquery_storage.BigQueryReadClient)
+        bqstorage_client._transport = mock.create_autospec(
+            big_query_read_grpc_transport.BigQueryReadGrpcTransport
+        )
+        streams = [
+            # Use two streams to check that frames are read from each stream.
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
+            {"name": "/projects/proj/dataset/dset/tables/tbl/streams/5678"},
+        ]
+        session = bigquery_storage.types.ReadSession(streams=streams)
+        arrow_schema = pyarrow.schema(
+            [
+                pyarrow.field("colA", pyarrow.int64()),
+                # Not alphabetical to test column order.
+ pyarrow.field("colC", pyarrow.float64()), + pyarrow.field("colB", pyarrow.string()), + ] + ) + session.arrow_schema.serialized_schema = arrow_schema.serialize().to_pybytes() + bqstorage_client.create_read_session.return_value = session + + mock_rowstream = mock.create_autospec(reader.ReadRowsStream) + bqstorage_client.read_rows.return_value = mock_rowstream + + mock_rows = mock.create_autospec(reader.ReadRowsIterable) + mock_rowstream.rows.return_value = mock_rows + page_items = [ + pyarrow.array([1, -1]), + pyarrow.array([2.0, 4.0]), + pyarrow.array(["abc", "def"]), + ] + + expected_record_batch = pyarrow.RecordBatch.from_arrays( + page_items, schema=arrow_schema + ) + expected_num_record_batches = 3 + + mock_page = mock.create_autospec(reader.ReadRowsPage) + mock_page.to_arrow.return_value = expected_record_batch + mock_pages = (mock_page,) * expected_num_record_batches + type(mock_rows).pages = mock.PropertyMock(return_value=mock_pages) + + schema = [ + schema.SchemaField("colA", "INTEGER"), + schema.SchemaField("colC", "FLOAT"), + schema.SchemaField("colB", "STRING"), + ] + + row_iterator = mut.RowIterator( + _mock_client(), + None, # api_request: ignored + None, # path: ignored + schema, + table=mut.TableReference.from_string("proj.dset.tbl"), + selected_fields=schema, + ) + + record_batches = list( + row_iterator.to_arrow_iterable(bqstorage_client=bqstorage_client) + ) + total_record_batches = len(streams) * len(mock_pages) + self.assertEqual(len(record_batches), total_record_batches) + + for record_batch in record_batches: + # Are the record batches return as expected? + self.assertEqual(record_batch, expected_record_batch) + + # Don't close the client if it was passed in. + bqstorage_client._transport.grpc_channel.close.assert_not_called() + @unittest.skipIf(pyarrow is None, "Requires `pyarrow`") def test_to_arrow(self): from google.cloud.bigquery.schema import SchemaField