feat: read_session optional to ReadRowsStream.rows() (#228)
* feat: `read_session` optional to `ReadRowsStream.rows()`

The schema from the first `ReadRowsResponse` message can be used to decode the
remaining messages instead (a usage sketch follows the file summary below).

Note: `to_arrow()` and `to_dataframe()` do not work on an empty stream unless a
`read_session` has been passed in, as the schema is not available. This should
not affect `google-cloud-bigquery` and `pandas-gbq`, as those packages use the
lower-level message->dataframe/arrow methods.

* revert change to comment

* use else for empty arrow streams in try-except block

Co-authored-by: Tres Seaver <tseaver@palladion.com>

* update docstring to reflect that `ReadSession` and `ReadRowsResponse` can be used interchangeably

* update arrow deserializer, too

Co-authored-by: Tres Seaver <tseaver@palladion.com>
tswast and tseaver committed Jul 9, 2021
1 parent a8a8c78 commit 4f56029
Showing 4 changed files with 196 additions and 140 deletions.
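For illustration, a minimal sketch of the new call pattern, following the library's documented quickstart flow; the project, dataset, table, and column names and the single-stream setup are assumptions for the example, not part of this change:

```python
from google.cloud import bigquery_storage_v1
from google.cloud.bigquery_storage_v1 import types

client = bigquery_storage_v1.BigQueryReadClient()

requested_session = types.ReadSession(
    table="projects/my-project/datasets/my_dataset/tables/my_table",
    data_format=types.DataFormat.AVRO,
)
session = client.create_read_session(
    parent="projects/my-project",
    read_session=requested_session,
    max_stream_count=1,
)
reader = client.read_rows(session.streams[0].name)

# Before this change, rows() needed the read session to find the schema:
#     rows = reader.rows(session)
# Now the schema comes from the first ReadRowsResponse, so the argument can
# be omitted (it is still accepted, but deprecated):
for row in reader.rows():
    print(row["my_column"])  # hypothetical column name
```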
125 changes: 91 additions & 34 deletions google/cloud/bigquery_storage_v1/reader.py
@@ -156,7 +156,7 @@ def _reconnect(self):
read_stream=self._name, offset=self._offset, **self._read_rows_kwargs
)

def rows(self, read_session):
def rows(self, read_session=None):
"""Iterate over all rows in the stream.
This method requires the fastavro library in order to parse row
@@ -169,19 +169,21 @@ def rows(self, read_session):
Args:
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
Returns:
Iterable[Mapping]:
A sequence of rows, represented as dictionaries.
"""
return ReadRowsIterable(self, read_session)
return ReadRowsIterable(self, read_session=read_session)

def to_arrow(self, read_session):
def to_arrow(self, read_session=None):
"""Create a :class:`pyarrow.Table` of all rows in the stream.
This method requires the pyarrow library and a stream using the Arrow
@@ -191,17 +193,19 @@ def to_arrow(self, read_session):
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
Returns:
pyarrow.Table:
A table of all rows in the stream.
"""
return self.rows(read_session).to_arrow()
return self.rows(read_session=read_session).to_arrow()

def to_dataframe(self, read_session, dtypes=None):
def to_dataframe(self, read_session=None, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.
This method requires the pandas libary to create a data frame and the
@@ -215,9 +219,11 @@ def to_dataframe(self, read_session, dtypes=None):
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
dtypes ( \
Map[str, Union[str, pandas.Series.dtype]] \
):
@@ -233,7 +239,7 @@ def to_dataframe(self, read_session, dtypes=None):
if pandas is None:
raise ImportError(_PANDAS_REQUIRED)

return self.rows(read_session).to_dataframe(dtypes=dtypes)
return self.rows(read_session=read_session).to_dataframe(dtypes=dtypes)


class ReadRowsIterable(object):
@@ -242,18 +248,25 @@ class ReadRowsIterable(object):
Args:
reader (google.cloud.bigquery_storage_v1.reader.ReadRowsStream):
A read rows stream.
read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
read_session ( \
Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
):
DEPRECATED.
This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
"""

# This class is modelled after the google.cloud.bigquery.table.RowIterator
# and aims to be API compatible where possible.

def __init__(self, reader, read_session):
def __init__(self, reader, read_session=None):
self._reader = reader
self._read_session = read_session
self._stream_parser = _StreamParser.from_read_session(self._read_session)
if read_session is not None:
self._stream_parser = _StreamParser.from_read_session(read_session)
else:
self._stream_parser = None

@property
def pages(self):
@@ -266,6 +279,10 @@ def pages(self):
# Each page is an iterator of rows. But also has num_items, remaining,
# and to_dataframe.
for message in self._reader:
# Only the first message contains the schema, which is needed to
# decode the messages.
if not self._stream_parser:
self._stream_parser = _StreamParser.from_read_rows_response(message)
yield ReadRowsPage(self._stream_parser, message)

def __iter__(self):
@@ -328,10 +345,11 @@ def to_dataframe(self, dtypes=None):
# pandas dataframe is about 2x faster. This is because pandas.concat is
# rarely no-copy, whereas pyarrow.Table.from_batches + to_pandas is
# usually no-copy.
schema_type = self._read_session._pb.WhichOneof("schema")

if schema_type == "arrow_schema":
try:
record_batch = self.to_arrow()
except NotImplementedError:
pass
else:
df = record_batch.to_pandas()
for column in dtypes:
df[column] = pandas.Series(df[column], dtype=dtypes[column])
@@ -491,6 +509,12 @@ def to_dataframe(self, message, dtypes=None):
def to_rows(self, message):
raise NotImplementedError("Not implemented.")

def _parse_avro_schema(self):
raise NotImplementedError("Not implemented.")

def _parse_arrow_schema(self):
raise NotImplementedError("Not implemented.")

@staticmethod
def from_read_session(read_session):
schema_type = read_session._pb.WhichOneof("schema")
@@ -503,22 +527,38 @@ def from_read_session(read_session):
"Unsupported schema type in read_session: {0}".format(schema_type)
)

@staticmethod
def from_read_rows_response(message):
schema_type = message._pb.WhichOneof("schema")
if schema_type == "avro_schema":
return _AvroStreamParser(message)
elif schema_type == "arrow_schema":
return _ArrowStreamParser(message)
else:
raise TypeError(
"Unsupported schema type in message: {0}".format(schema_type)
)


class _AvroStreamParser(_StreamParser):
"""Helper to parse Avro messages into useful representations."""

def __init__(self, read_session):
def __init__(self, message):
"""Construct an _AvroStreamParser.
Args:
read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
message (Union[
google.cloud.bigquery_storage_v1.types.ReadSession, \
google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
]):
Either the first message of data from a read rows stream or a
read session. Both types contain a oneof "schema" field, which
can be used to determine how to deserialize rows.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)

self._read_session = read_session
self._first_message = message
self._avro_schema_json = None
self._fastavro_schema = None
self._column_names = None
@@ -548,6 +588,10 @@ def to_dataframe(self, message, dtypes=None):
strings in the fastavro library.
Args:
message ( \
~google.cloud.bigquery_storage_v1.types.ReadRowsResponse \
):
A message containing Avro bytes to parse into a pandas DataFrame.
dtypes ( \
Map[str, Union[str, pandas.Series.dtype]] \
):
@@ -578,10 +622,11 @@ def _parse_avro_schema(self):
if self._avro_schema_json:
return

self._avro_schema_json = json.loads(self._read_session.avro_schema.schema)
self._avro_schema_json = json.loads(self._first_message.avro_schema.schema)
self._column_names = tuple(
(field["name"] for field in self._avro_schema_json["fields"])
)
self._first_message = None

def _parse_fastavro(self):
"""Convert parsed Avro schema to fastavro format."""
@@ -615,11 +660,22 @@ def to_rows(self, message):


class _ArrowStreamParser(_StreamParser):
def __init__(self, read_session):
def __init__(self, message):
"""Construct an _ArrowStreamParser.
Args:
message (Union[
google.cloud.bigquery_storage_v1.types.ReadSession, \
google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
]):
Either the first message of data from a read rows stream or a
read session. Both types contain a oneof "schema" field, which
can be used to determine how to deserialize rows.
"""
if pyarrow is None:
raise ImportError(_PYARROW_REQUIRED)

self._read_session = read_session
self._first_message = message
self._schema = None

def to_arrow(self, message):
@@ -659,6 +715,7 @@ def _parse_arrow_schema(self):
return

self._schema = pyarrow.ipc.read_schema(
pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
pyarrow.py_buffer(self._first_message.arrow_schema.serialized_schema)
)
self._column_names = [field.name for field in self._schema]
self._first_message = None
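As the commit message notes, `google-cloud-bigquery` and `pandas-gbq` consume the stream through the lower-level page/message methods rather than the stream-level `to_dataframe()`, which is why the empty-stream limitation does not affect them. A hedged sketch of that page-level path, reusing the hypothetical `reader` from the earlier example and handling the empty-stream case explicitly (no pages means no schema when `read_session` was not supplied):

```python
import pandas

# Each page corresponds to one ReadRowsResponse; the parser is created lazily
# from the first message, so no read_session is needed here.
frames = [page.to_dataframe() for page in reader.rows().pages]

# An empty stream produces no pages; fall back to an empty DataFrame instead
# of calling reader.to_dataframe(), which would need a schema it cannot get.
df = pandas.concat(frames, ignore_index=True) if frames else pandas.DataFrame()
```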
77 changes: 40 additions & 37 deletions tests/system/conftest.py
@@ -18,13 +18,41 @@
import os
import uuid

import google.auth
from google.cloud import bigquery
import pytest
import test_utils.prefixer

from . import helpers


prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")


_TABLE_FORMAT = "projects/{}/datasets/{}/tables/{}"
_ASSETS_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "assets")
_ALL_TYPES_SCHEMA = [
bigquery.SchemaField("string_field", "STRING"),
bigquery.SchemaField("bytes_field", "BYTES"),
bigquery.SchemaField("int64_field", "INT64"),
bigquery.SchemaField("float64_field", "FLOAT64"),
bigquery.SchemaField("numeric_field", "NUMERIC"),
bigquery.SchemaField("bool_field", "BOOL"),
bigquery.SchemaField("geography_field", "GEOGRAPHY"),
bigquery.SchemaField(
"person_struct_field",
"STRUCT",
fields=(
bigquery.SchemaField("name", "STRING"),
bigquery.SchemaField("age", "INT64"),
),
),
bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
bigquery.SchemaField("date_field", "DATE"),
bigquery.SchemaField("time_field", "TIME"),
bigquery.SchemaField("datetime_field", "DATETIME"),
bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
]


@pytest.fixture(scope="session")
@@ -38,18 +66,9 @@ def use_mtls():


@pytest.fixture(scope="session")
def credentials(use_mtls):
import google.auth
from google.oauth2 import service_account

if use_mtls:
# mTLS test uses user credentials instead of service account credentials
creds, _ = google.auth.default()
return creds

# NOTE: the test config in noxfile checks that the env variable is indeed set
filename = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
return service_account.Credentials.from_service_account_file(filename)
def credentials():
creds, _ = google.auth.default()
return creds


@pytest.fixture()
Expand Down Expand Up @@ -77,8 +96,7 @@ def local_shakespeare_table_reference(project_id, use_mtls):
def dataset(project_id, bq_client):
from google.cloud import bigquery

unique_suffix = str(uuid.uuid4()).replace("-", "_")
dataset_name = "bq_storage_system_tests_" + unique_suffix
dataset_name = prefixer.create_prefix()

dataset_id = "{}.{}".format(project_id, dataset_name)
dataset = bigquery.Dataset(dataset_id)
@@ -120,35 +138,20 @@ def bq_client(credentials, use_mtls):
return bigquery.Client(credentials=credentials)


@pytest.fixture(scope="session", autouse=True)
def cleanup_datasets(bq_client: bigquery.Client):
for dataset in bq_client.list_datasets():
if prefixer.should_cleanup(dataset.dataset_id):
bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)


@pytest.fixture
def all_types_table_ref(project_id, dataset, bq_client):
from google.cloud import bigquery

schema = [
bigquery.SchemaField("string_field", "STRING"),
bigquery.SchemaField("bytes_field", "BYTES"),
bigquery.SchemaField("int64_field", "INT64"),
bigquery.SchemaField("float64_field", "FLOAT64"),
bigquery.SchemaField("numeric_field", "NUMERIC"),
bigquery.SchemaField("bool_field", "BOOL"),
bigquery.SchemaField("geography_field", "GEOGRAPHY"),
bigquery.SchemaField(
"person_struct_field",
"STRUCT",
fields=(
bigquery.SchemaField("name", "STRING"),
bigquery.SchemaField("age", "INT64"),
),
),
bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
bigquery.SchemaField("date_field", "DATE"),
bigquery.SchemaField("time_field", "TIME"),
bigquery.SchemaField("datetime_field", "DATETIME"),
bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
]
bq_table = bigquery.table.Table(
table_ref="{}.{}.complex_records".format(project_id, dataset.dataset_id),
schema=schema,
schema=_ALL_TYPES_SCHEMA,
)

created_table = bq_client.create_table(bq_table)
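The conftest changes replace ad-hoc `uuid` suffixes with `test_utils.prefixer`, so dataset names carry a recognizable prefix that the new autouse `cleanup_datasets` fixture can match when sweeping up resources left behind by earlier (possibly crashed) runs. A small sketch of that pattern, using only the `Prefixer` calls that appear in the diff; the project ID is a placeholder:

```python
import test_utils.prefixer
from google.cloud import bigquery

prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")
bq_client = bigquery.Client()

# New datasets get a prefixed, run-specific name instead of a bare uuid suffix.
dataset_id = "{}.{}".format("my-project", prefixer.create_prefix())
dataset = bq_client.create_dataset(bigquery.Dataset(dataset_id))

# A later session can then recognize and delete datasets orphaned by old runs.
for existing in bq_client.list_datasets():
    if prefixer.should_cleanup(existing.dataset_id):
        bq_client.delete_dataset(existing, delete_contents=True, not_found_ok=True)
```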