From 4f5602950a0c1959e332aa2964245b9caf4828c8 Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Fri, 9 Jul 2021 09:25:04 -0500 Subject: [PATCH] feat: `read_session` optional to `ReadRowsStream.rows()` (#228) * feat: `read_session` optional to `ReadRowsStream.rows()` The schema from the first `ReadRowsResponse` message can be used to decode messages, instead. Note: `to_arrow()` and `to_dataframe()` do not work on an empty stream unless a `read_session` has been passed in, as the schema is not available. This should not affect `google-cloud-bigquery` and `pandas-gbq`, as those packages use the lower-level message->dataframe/arrow methods. * revert change to comment * use else for empty arrow streams in try-except block Co-authored-by: Tres Seaver * update docstring to reflect that readsession and readrowsresponse can be used interchangeably * update arrow deserializer, too Co-authored-by: Tres Seaver --- google/cloud/bigquery_storage_v1/reader.py | 125 +++++++++++++++------ tests/system/conftest.py | 77 +++++++------ tests/unit/test_reader_v1.py | 85 +++++++------- tests/unit/test_reader_v1_arrow.py | 49 ++++---- 4 files changed, 196 insertions(+), 140 deletions(-) diff --git a/google/cloud/bigquery_storage_v1/reader.py b/google/cloud/bigquery_storage_v1/reader.py index 034ad726..a8cd226c 100644 --- a/google/cloud/bigquery_storage_v1/reader.py +++ b/google/cloud/bigquery_storage_v1/reader.py @@ -156,7 +156,7 @@ def _reconnect(self): read_stream=self._name, offset=self._offset, **self._read_rows_kwargs ) - def rows(self, read_session): + def rows(self, read_session=None): """Iterate over all rows in the stream. This method requires the fastavro library in order to parse row @@ -169,19 +169,21 @@ def rows(self, read_session): Args: read_session ( \ - ~google.cloud.bigquery_storage_v1.types.ReadSession \ + Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \ ): - The read session associated with this read rows stream. This - contains the schema, which is required to parse the data - messages. + DEPRECATED. + + This argument was used to specify the schema of the rows in the + stream, but now the first message in a read stream contains + this information. Returns: Iterable[Mapping]: A sequence of rows, represented as dictionaries. """ - return ReadRowsIterable(self, read_session) + return ReadRowsIterable(self, read_session=read_session) - def to_arrow(self, read_session): + def to_arrow(self, read_session=None): """Create a :class:`pyarrow.Table` of all rows in the stream. This method requires the pyarrow library and a stream using the Arrow @@ -191,17 +193,19 @@ def to_arrow(self, read_session): read_session ( \ ~google.cloud.bigquery_storage_v1.types.ReadSession \ ): - The read session associated with this read rows stream. This - contains the schema, which is required to parse the data - messages. + DEPRECATED. + + This argument was used to specify the schema of the rows in the + stream, but now the first message in a read stream contains + this information. Returns: pyarrow.Table: A table of all rows in the stream. """ - return self.rows(read_session).to_arrow() + return self.rows(read_session=read_session).to_arrow() - def to_dataframe(self, read_session, dtypes=None): + def to_dataframe(self, read_session=None, dtypes=None): """Create a :class:`pandas.DataFrame` of all rows in the stream. 
This method requires the pandas libary to create a data frame and the @@ -215,9 +219,11 @@ def to_dataframe(self, read_session, dtypes=None): read_session ( \ ~google.cloud.bigquery_storage_v1.types.ReadSession \ ): - The read session associated with this read rows stream. This - contains the schema, which is required to parse the data - messages. + DEPRECATED. + + This argument was used to specify the schema of the rows in the + stream, but now the first message in a read stream contains + this information. dtypes ( \ Map[str, Union[str, pandas.Series.dtype]] \ ): @@ -233,7 +239,7 @@ def to_dataframe(self, read_session, dtypes=None): if pandas is None: raise ImportError(_PANDAS_REQUIRED) - return self.rows(read_session).to_dataframe(dtypes=dtypes) + return self.rows(read_session=read_session).to_dataframe(dtypes=dtypes) class ReadRowsIterable(object): @@ -242,18 +248,25 @@ class ReadRowsIterable(object): Args: reader (google.cloud.bigquery_storage_v1.reader.ReadRowsStream): A read rows stream. - read_session (google.cloud.bigquery_storage_v1.types.ReadSession): - A read session. This is required because it contains the schema - used in the stream messages. + read_session ( \ + Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \ + ): + DEPRECATED. + + This argument was used to specify the schema of the rows in the + stream, but now the first message in a read stream contains + this information. """ # This class is modelled after the google.cloud.bigquery.table.RowIterator # and aims to be API compatible where possible. - def __init__(self, reader, read_session): + def __init__(self, reader, read_session=None): self._reader = reader - self._read_session = read_session - self._stream_parser = _StreamParser.from_read_session(self._read_session) + if read_session is not None: + self._stream_parser = _StreamParser.from_read_session(read_session) + else: + self._stream_parser = None @property def pages(self): @@ -266,6 +279,10 @@ def pages(self): # Each page is an iterator of rows. But also has num_items, remaining, # and to_dataframe. for message in self._reader: + # Only the first message contains the schema, which is needed to + # decode the messages. + if not self._stream_parser: + self._stream_parser = _StreamParser.from_read_rows_response(message) yield ReadRowsPage(self._stream_parser, message) def __iter__(self): @@ -328,10 +345,11 @@ def to_dataframe(self, dtypes=None): # pandas dataframe is about 2x faster. This is because pandas.concat is # rarely no-copy, whereas pyarrow.Table.from_batches + to_pandas is # usually no-copy. 
- schema_type = self._read_session._pb.WhichOneof("schema") - - if schema_type == "arrow_schema": + try: record_batch = self.to_arrow() + except NotImplementedError: + pass + else: df = record_batch.to_pandas() for column in dtypes: df[column] = pandas.Series(df[column], dtype=dtypes[column]) @@ -491,6 +509,12 @@ def to_dataframe(self, message, dtypes=None): def to_rows(self, message): raise NotImplementedError("Not implemented.") + def _parse_avro_schema(self): + raise NotImplementedError("Not implemented.") + + def _parse_arrow_schema(self): + raise NotImplementedError("Not implemented.") + @staticmethod def from_read_session(read_session): schema_type = read_session._pb.WhichOneof("schema") @@ -503,22 +527,38 @@ def from_read_session(read_session): "Unsupported schema type in read_session: {0}".format(schema_type) ) + @staticmethod + def from_read_rows_response(message): + schema_type = message._pb.WhichOneof("schema") + if schema_type == "avro_schema": + return _AvroStreamParser(message) + elif schema_type == "arrow_schema": + return _ArrowStreamParser(message) + else: + raise TypeError( + "Unsupported schema type in message: {0}".format(schema_type) + ) + class _AvroStreamParser(_StreamParser): """Helper to parse Avro messages into useful representations.""" - def __init__(self, read_session): + def __init__(self, message): """Construct an _AvroStreamParser. Args: - read_session (google.cloud.bigquery_storage_v1.types.ReadSession): - A read session. This is required because it contains the schema - used in the stream messages. + message (Union[ + google.cloud.bigquery_storage_v1.types.ReadSession, \ + google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \ + ]): + Either the first message of data from a read rows stream or a + read session. Both types contain a oneof "schema" field, which + can be used to determine how to deserialize rows. """ if fastavro is None: raise ImportError(_FASTAVRO_REQUIRED) - self._read_session = read_session + self._first_message = message self._avro_schema_json = None self._fastavro_schema = None self._column_names = None @@ -548,6 +588,10 @@ def to_dataframe(self, message, dtypes=None): strings in the fastavro library. Args: + message ( \ + ~google.cloud.bigquery_storage_v1.types.ReadRowsResponse \ + ): + A message containing Avro bytes to parse into a pandas DataFrame. dtypes ( \ Map[str, Union[str, pandas.Series.dtype]] \ ): @@ -578,10 +622,11 @@ def _parse_avro_schema(self): if self._avro_schema_json: return - self._avro_schema_json = json.loads(self._read_session.avro_schema.schema) + self._avro_schema_json = json.loads(self._first_message.avro_schema.schema) self._column_names = tuple( (field["name"] for field in self._avro_schema_json["fields"]) ) + self._first_message = None def _parse_fastavro(self): """Convert parsed Avro schema to fastavro format.""" @@ -615,11 +660,22 @@ def to_rows(self, message): class _ArrowStreamParser(_StreamParser): - def __init__(self, read_session): + def __init__(self, message): + """Construct an _ArrowStreamParser. + + Args: + message (Union[ + google.cloud.bigquery_storage_v1.types.ReadSession, \ + google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \ + ]): + Either the first message of data from a read rows stream or a + read session. Both types contain a oneof "schema" field, which + can be used to determine how to deserialize rows. 
+ """ if pyarrow is None: raise ImportError(_PYARROW_REQUIRED) - self._read_session = read_session + self._first_message = message self._schema = None def to_arrow(self, message): @@ -659,6 +715,7 @@ def _parse_arrow_schema(self): return self._schema = pyarrow.ipc.read_schema( - pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema) + pyarrow.py_buffer(self._first_message.arrow_schema.serialized_schema) ) self._column_names = [field.name for field in self._schema] + self._first_message = None diff --git a/tests/system/conftest.py b/tests/system/conftest.py index a18777dd..3a89097a 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -18,13 +18,41 @@ import os import uuid +import google.auth +from google.cloud import bigquery import pytest +import test_utils.prefixer from . import helpers +prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system") + + _TABLE_FORMAT = "projects/{}/datasets/{}/tables/{}" _ASSETS_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "assets") +_ALL_TYPES_SCHEMA = [ + bigquery.SchemaField("string_field", "STRING"), + bigquery.SchemaField("bytes_field", "BYTES"), + bigquery.SchemaField("int64_field", "INT64"), + bigquery.SchemaField("float64_field", "FLOAT64"), + bigquery.SchemaField("numeric_field", "NUMERIC"), + bigquery.SchemaField("bool_field", "BOOL"), + bigquery.SchemaField("geography_field", "GEOGRAPHY"), + bigquery.SchemaField( + "person_struct_field", + "STRUCT", + fields=( + bigquery.SchemaField("name", "STRING"), + bigquery.SchemaField("age", "INT64"), + ), + ), + bigquery.SchemaField("timestamp_field", "TIMESTAMP"), + bigquery.SchemaField("date_field", "DATE"), + bigquery.SchemaField("time_field", "TIME"), + bigquery.SchemaField("datetime_field", "DATETIME"), + bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"), +] @pytest.fixture(scope="session") @@ -38,18 +66,9 @@ def use_mtls(): @pytest.fixture(scope="session") -def credentials(use_mtls): - import google.auth - from google.oauth2 import service_account - - if use_mtls: - # mTLS test uses user credentials instead of service account credentials - creds, _ = google.auth.default() - return creds - - # NOTE: the test config in noxfile checks that the env variable is indeed set - filename = os.environ["GOOGLE_APPLICATION_CREDENTIALS"] - return service_account.Credentials.from_service_account_file(filename) +def credentials(): + creds, _ = google.auth.default() + return creds @pytest.fixture() @@ -77,8 +96,7 @@ def local_shakespeare_table_reference(project_id, use_mtls): def dataset(project_id, bq_client): from google.cloud import bigquery - unique_suffix = str(uuid.uuid4()).replace("-", "_") - dataset_name = "bq_storage_system_tests_" + unique_suffix + dataset_name = prefixer.create_prefix() dataset_id = "{}.{}".format(project_id, dataset_name) dataset = bigquery.Dataset(dataset_id) @@ -120,35 +138,20 @@ def bq_client(credentials, use_mtls): return bigquery.Client(credentials=credentials) +@pytest.fixture(scope="session", autouse=True) +def cleanup_datasets(bq_client: bigquery.Client): + for dataset in bq_client.list_datasets(): + if prefixer.should_cleanup(dataset.dataset_id): + bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True) + + @pytest.fixture def all_types_table_ref(project_id, dataset, bq_client): from google.cloud import bigquery - schema = [ - bigquery.SchemaField("string_field", "STRING"), - bigquery.SchemaField("bytes_field", "BYTES"), - bigquery.SchemaField("int64_field", "INT64"), - 
bigquery.SchemaField("float64_field", "FLOAT64"), - bigquery.SchemaField("numeric_field", "NUMERIC"), - bigquery.SchemaField("bool_field", "BOOL"), - bigquery.SchemaField("geography_field", "GEOGRAPHY"), - bigquery.SchemaField( - "person_struct_field", - "STRUCT", - fields=( - bigquery.SchemaField("name", "STRING"), - bigquery.SchemaField("age", "INT64"), - ), - ), - bigquery.SchemaField("timestamp_field", "TIMESTAMP"), - bigquery.SchemaField("date_field", "DATE"), - bigquery.SchemaField("time_field", "TIME"), - bigquery.SchemaField("datetime_field", "DATETIME"), - bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"), - ] bq_table = bigquery.table.Table( table_ref="{}.{}.complex_records".format(project_id, dataset.dataset_id), - schema=schema, + schema=_ALL_TYPES_SCHEMA, ) created_table = bq_client.create_table(bq_table) diff --git a/tests/unit/test_reader_v1.py b/tests/unit/test_reader_v1.py index 7fb8d5a4..838ef51a 100644 --- a/tests/unit/test_reader_v1.py +++ b/tests/unit/test_reader_v1.py @@ -66,6 +66,7 @@ def mock_gapic_client(): def _bq_to_avro_blocks(bq_blocks, avro_schema_json): avro_schema = fastavro.parse_schema(avro_schema_json) avro_blocks = [] + first_message = True for block in bq_blocks: blockio = six.BytesIO() for row in block: @@ -73,6 +74,9 @@ def _bq_to_avro_blocks(bq_blocks, avro_schema_json): response = types.ReadRowsResponse() response.row_count = len(block) response.avro_rows.serialized_binary_rows = blockio.getvalue() + if first_message: + response.avro_schema = {"schema": json.dumps(avro_schema_json)} + first_message = False avro_blocks.append(response) return avro_blocks @@ -128,54 +132,48 @@ def _bq_to_avro_schema(bq_columns): return avro_schema -def _get_avro_bytes(rows, avro_schema): - avro_file = six.BytesIO() - for row in rows: - fastavro.schemaless_writer(avro_file, avro_schema, row) - return avro_file.getvalue() - - def test_avro_rows_raises_import_error( mut, class_under_test, mock_gapic_client, monkeypatch ): monkeypatch.setattr(mut, "fastavro", None) - reader = class_under_test([], mock_gapic_client, "", 0, {}) - - bq_columns = [{"name": "int_col", "type": "int64"}] - avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) + avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) + avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) + reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + rows = iter(reader.rows()) + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. with pytest.raises(ImportError): - reader.rows(read_session) + next(rows) def test_rows_no_schema_set_raises_type_error( mut, class_under_test, mock_gapic_client, monkeypatch ): - reader = class_under_test([], mock_gapic_client, "", 0, {}) - read_session = types.ReadSession() + avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) + avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) + avro_blocks[0].avro_schema = None + reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + rows = iter(reader.rows()) + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. 
with pytest.raises(TypeError): - reader.rows(read_session) + next(rows) def test_rows_w_empty_stream(class_under_test, mock_gapic_client): - bq_columns = [{"name": "int_col", "type": "int64"}] - avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) reader = class_under_test([], mock_gapic_client, "", 0, {}) - - got = reader.rows(read_session) + got = reader.rows() assert tuple(got) == () def test_rows_w_scalars(class_under_test, mock_gapic_client): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) - got = tuple(reader.rows(read_session)) + got = tuple(reader.rows()) expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS)) assert got == expected @@ -184,7 +182,6 @@ def test_rows_w_scalars(class_under_test, mock_gapic_client): def test_rows_w_timeout(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -206,7 +203,7 @@ def test_rows_w_timeout(class_under_test, mock_gapic_client): ) with pytest.raises(google.api_core.exceptions.DeadlineExceeded): - list(reader.rows(read_session)) + list(reader.rows()) # Don't reconnect on DeadlineException. This allows user-specified timeouts # to be respected. @@ -216,7 +213,6 @@ def test_rows_w_timeout(class_under_test, mock_gapic_client): def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks = [[{"int_col": 1024}, {"int_col": 512}], [{"int_col": 256}]] avro_blocks = _pages_w_nonresumable_internal_error( _bq_to_avro_blocks(bq_blocks, avro_schema) @@ -227,7 +223,7 @@ def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client) with pytest.raises( google.api_core.exceptions.InternalServerError, match="nonresumable error" ): - list(reader.rows(read_session)) + list(reader.rows()) mock_gapic_client.read_rows.assert_not_called() @@ -235,7 +231,6 @@ def test_rows_w_nonresumable_internal_error(class_under_test, mock_gapic_client) def test_rows_w_reconnect(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -258,7 +253,7 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): 0, {"metadata": {"test-key": "test-value"}}, ) - got = reader.rows(read_session) + got = reader.rows() expected = tuple( itertools.chain( @@ -280,7 +275,6 @@ def test_rows_w_reconnect(class_under_test, mock_gapic_client): def test_rows_w_reconnect_by_page(class_under_test, mock_gapic_client): bq_columns = [{"name": "int_col", "type": "int64"}] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) bq_blocks_1 = [ [{"int_col": 123}, {"int_col": 234}], [{"int_col": 345}, {"int_col": 456}], @@ -298,7 +292,7 @@ def test_rows_w_reconnect_by_page(class_under_test, mock_gapic_client): 0, {"metadata": {"test-key": "test-value"}}, ) - 
got = reader.rows(read_session) + got = reader.rows() pages = iter(got.pages) page_1 = next(pages) @@ -330,38 +324,41 @@ def test_to_dataframe_no_pandas_raises_import_error( ): monkeypatch.setattr(mut, "pandas", None) avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) with pytest.raises(ImportError): - reader.to_dataframe(read_session) + reader.to_dataframe() with pytest.raises(ImportError): - reader.rows(read_session).to_dataframe() + reader.rows().to_dataframe() with pytest.raises(ImportError): - next(reader.rows(read_session).pages).to_dataframe() + next(reader.rows().pages).to_dataframe() def test_to_dataframe_no_schema_set_raises_type_error( mut, class_under_test, mock_gapic_client, monkeypatch ): - reader = class_under_test([], mock_gapic_client, "", 0, {}) - read_session = types.ReadSession() + avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) + avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) + avro_blocks[0].avro_schema = None + reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + rows = reader.rows() + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. with pytest.raises(TypeError): - reader.to_dataframe(read_session) + rows.to_dataframe() def test_to_dataframe_w_scalars(class_under_test): avro_schema = _bq_to_avro_schema(SCALAR_COLUMNS) - read_session = _generate_avro_read_session(avro_schema) avro_blocks = _bq_to_avro_blocks(SCALAR_BLOCKS, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session) + got = reader.to_dataframe() expected = pandas.DataFrame( list(itertools.chain.from_iterable(SCALAR_BLOCKS)), columns=SCALAR_COLUMN_NAMES @@ -392,7 +389,6 @@ def test_to_dataframe_w_dtypes(class_under_test): {"name": "lilfloat", "type": "float64"}, ] ) - read_session = _generate_avro_read_session(avro_schema) blocks = [ [{"bigfloat": 1.25, "lilfloat": 30.5}, {"bigfloat": 2.5, "lilfloat": 21.125}], [{"bigfloat": 3.75, "lilfloat": 11.0}], @@ -400,7 +396,7 @@ def test_to_dataframe_w_dtypes(class_under_test): avro_blocks = _bq_to_avro_blocks(blocks, avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) + got = reader.to_dataframe(dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame( { @@ -421,6 +417,7 @@ def test_to_dataframe_empty_w_scalars_avro(class_under_test): avro_blocks = _bq_to_avro_blocks([], avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session) expected = pandas.DataFrame(columns=SCALAR_COLUMN_NAMES) @@ -448,6 +445,7 @@ def test_to_dataframe_empty_w_dtypes_avro(class_under_test, mock_gapic_client): avro_blocks = _bq_to_avro_blocks([], avro_schema) reader = class_under_test(avro_blocks, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. 
got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame([], columns=["bigfloat", "lilfloat"]) @@ -466,7 +464,6 @@ def test_to_dataframe_by_page(class_under_test, mock_gapic_client): {"name": "bool_col", "type": "bool"}, ] avro_schema = _bq_to_avro_schema(bq_columns) - read_session = _generate_avro_read_session(avro_schema) block_1 = [{"int_col": 123, "bool_col": True}, {"int_col": 234, "bool_col": False}] block_2 = [{"int_col": 345, "bool_col": True}, {"int_col": 456, "bool_col": False}] block_3 = [{"int_col": 567, "bool_col": True}, {"int_col": 789, "bool_col": False}] @@ -487,7 +484,7 @@ def test_to_dataframe_by_page(class_under_test, mock_gapic_client): 0, {"metadata": {"test-key": "test-value"}}, ) - got = reader.rows(read_session) + got = reader.rows() pages = iter(got.pages) page_1 = next(pages) diff --git a/tests/unit/test_reader_v1_arrow.py b/tests/unit/test_reader_v1_arrow.py index 492098f5..02c7b80a 100644 --- a/tests/unit/test_reader_v1_arrow.py +++ b/tests/unit/test_reader_v1_arrow.py @@ -84,11 +84,17 @@ def _bq_to_arrow_batch_objects(bq_blocks, arrow_schema): def _bq_to_arrow_batches(bq_blocks, arrow_schema): arrow_batches = [] + first_message = True for record_batch in _bq_to_arrow_batch_objects(bq_blocks, arrow_schema): response = types.ReadRowsResponse() response.arrow_record_batch.serialized_record_batch = ( record_batch.serialize().to_pybytes() ) + if first_message: + response.arrow_schema = { + "serialized_schema": arrow_schema.serialize().to_pybytes(), + } + first_message = False arrow_batches.append(response) return arrow_batches @@ -123,14 +129,15 @@ def test_pyarrow_rows_raises_import_error( mut, class_under_test, mock_gapic_client, monkeypatch ): monkeypatch.setattr(mut, "pyarrow", None) - reader = class_under_test([], mock_gapic_client, "", 0, {}) - - bq_columns = [{"name": "int_col", "type": "int64"}] - arrow_schema = _bq_to_arrow_schema(bq_columns) - read_session = _generate_arrow_read_session(arrow_schema) + arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) + arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) + reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + rows = iter(reader.rows()) + # Since session isn't passed in, reader doesn't know serialization type + # until you start iterating. 
with pytest.raises(ImportError): - reader.rows(read_session) + next(rows) def test_to_arrow_no_pyarrow_raises_import_error( @@ -138,26 +145,24 @@ def test_to_arrow_no_pyarrow_raises_import_error( ): monkeypatch.setattr(mut, "pyarrow", None) arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) with pytest.raises(ImportError): - reader.to_arrow(read_session) + reader.to_arrow() with pytest.raises(ImportError): - reader.rows(read_session).to_arrow() + reader.rows().to_arrow() with pytest.raises(ImportError): - next(reader.rows(read_session).pages).to_arrow() + next(reader.rows().pages).to_arrow() def test_to_arrow_w_scalars_arrow(class_under_test): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) - actual_table = reader.to_arrow(read_session) + actual_table = reader.to_arrow() expected_table = pyarrow.Table.from_batches( _bq_to_arrow_batch_objects(SCALAR_BLOCKS, arrow_schema) ) @@ -166,11 +171,10 @@ def test_to_arrow_w_scalars_arrow(class_under_test): def test_to_dataframe_w_scalars_arrow(class_under_test): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session) + got = reader.to_dataframe() expected = pandas.DataFrame( list(itertools.chain.from_iterable(SCALAR_BLOCKS)), columns=SCALAR_COLUMN_NAMES @@ -183,24 +187,19 @@ def test_to_dataframe_w_scalars_arrow(class_under_test): def test_rows_w_empty_stream_arrow(class_under_test, mock_gapic_client): - bq_columns = [{"name": "int_col", "type": "int64"}] - arrow_schema = _bq_to_arrow_schema(bq_columns) - read_session = _generate_arrow_read_session(arrow_schema) reader = class_under_test([], mock_gapic_client, "", 0, {}) - - got = reader.rows(read_session) + got = reader.rows() assert tuple(got) == () def test_rows_w_scalars_arrow(class_under_test, mock_gapic_client): arrow_schema = _bq_to_arrow_schema(SCALAR_COLUMNS) - read_session = _generate_arrow_read_session(arrow_schema) arrow_batches = _bq_to_arrow_batches(SCALAR_BLOCKS, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) got = tuple( dict((key, value.as_py()) for key, value in row_dict.items()) - for row_dict in reader.rows(read_session) + for row_dict in reader.rows() ) expected = tuple(itertools.chain.from_iterable(SCALAR_BLOCKS)) @@ -214,7 +213,6 @@ def test_to_dataframe_w_dtypes_arrow(class_under_test): {"name": "lilfloat", "type": "float64"}, ] ) - read_session = _generate_arrow_read_session(arrow_schema) blocks = [ [{"bigfloat": 1.25, "lilfloat": 30.5}, {"bigfloat": 2.5, "lilfloat": 21.125}], [{"bigfloat": 3.75, "lilfloat": 11.0}], @@ -222,7 +220,7 @@ def test_to_dataframe_w_dtypes_arrow(class_under_test): arrow_batches = _bq_to_arrow_batches(blocks, arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) - got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) + got = reader.to_dataframe(dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame( { @@ -243,6 +241,7 @@ def 
test_to_dataframe_empty_w_scalars_arrow(class_under_test): arrow_batches = _bq_to_arrow_batches([], arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session) expected = pandas.DataFrame([], columns=SCALAR_COLUMN_NAMES) @@ -270,6 +269,7 @@ def test_to_dataframe_empty_w_dtypes_arrow(class_under_test, mock_gapic_client): arrow_batches = _bq_to_arrow_batches([], arrow_schema) reader = class_under_test(arrow_batches, mock_gapic_client, "", 0, {}) + # Read session is needed to get a schema for empty streams. got = reader.to_dataframe(read_session, dtypes={"lilfloat": "float16"}) expected = pandas.DataFrame([], columns=["bigfloat", "lilfloat"]) @@ -288,7 +288,6 @@ def test_to_dataframe_by_page_arrow(class_under_test, mock_gapic_client): {"name": "bool_col", "type": "bool"}, ] arrow_schema = _bq_to_arrow_schema(bq_columns) - read_session = _generate_arrow_read_session(arrow_schema) bq_block_1 = [ {"int_col": 123, "bool_col": True}, @@ -315,7 +314,7 @@ def test_to_dataframe_by_page_arrow(class_under_test, mock_gapic_client): reader = class_under_test( _pages_w_unavailable(batch_1), mock_gapic_client, "", 0, {} ) - got = reader.rows(read_session) + got = reader.rows() pages = iter(got.pages) page_1 = next(pages)
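
Usage sketch (illustrative only, not part of the patch): the snippet below shows the call pattern this change enables, i.e. `ReadRowsStream.rows()` and friends without a `read_session` argument. The project and table names are placeholders, and the example assumes the public `BigQueryReadClient` API plus the `fastavro`/`pandas` extras installed.

    # Hypothetical usage sketch for the optional read_session change.
    # Project/dataset/table identifiers below are placeholders.
    from google.cloud.bigquery_storage_v1 import BigQueryReadClient, types

    client = BigQueryReadClient()

    requested_session = types.ReadSession(
        table="projects/my-project/datasets/my_dataset/tables/my_table",
        data_format=types.DataFormat.AVRO,
    )
    session = client.create_read_session(
        parent="projects/my-project",
        read_session=requested_session,
        max_stream_count=1,
    )

    reader = client.read_rows(session.streams[0].name)

    # After this change, rows() no longer requires the read session: the
    # schema is taken from the first ReadRowsResponse message in the stream.
    for row in reader.rows():
        print(row)

    # Passing the session is still accepted, and per the note in the commit
    # message it remains necessary for to_dataframe()/to_arrow() on an empty
    # stream, where no schema-bearing message is ever received.
    reader = client.read_rows(session.streams[0].name)
    df = reader.to_dataframe(read_session=session)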