feat: read_session optional to ReadRowsStream.rows() #228

Merged: 5 commits on Jul 9, 2021
Changes from all commits
125 changes: 91 additions & 34 deletions google/cloud/bigquery_storage_v1/reader.py
@@ -156,7 +156,7 @@ def _reconnect(self):
read_stream=self._name, offset=self._offset, **self._read_rows_kwargs
)

def rows(self, read_session):
def rows(self, read_session=None):
"""Iterate over all rows in the stream.

This method requires the fastavro library in order to parse row
@@ -169,19 +169,21 @@ def rows(self, read_session):

Args:
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.

This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.

Returns:
Iterable[Mapping]:
A sequence of rows, represented as dictionaries.
"""
return ReadRowsIterable(self, read_session)
return ReadRowsIterable(self, read_session=read_session)

def to_arrow(self, read_session):
def to_arrow(self, read_session=None):
"""Create a :class:`pyarrow.Table` of all rows in the stream.

This method requires the pyarrow library and a stream using the Arrow
@@ -191,17 +193,19 @@ def rows(self, read_session):
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.

This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.

Returns:
pyarrow.Table:
A table of all rows in the stream.
"""
return self.rows(read_session).to_arrow()
return self.rows(read_session=read_session).to_arrow()

def to_dataframe(self, read_session, dtypes=None):
def to_dataframe(self, read_session=None, dtypes=None):
"""Create a :class:`pandas.DataFrame` of all rows in the stream.

This method requires the pandas library to create a data frame and the
@@ -215,9 +219,11 @@ def to_dataframe(self, read_session, dtypes=None):
read_session ( \
~google.cloud.bigquery_storage_v1.types.ReadSession \
):
The read session associated with this read rows stream. This
contains the schema, which is required to parse the data
messages.
DEPRECATED.

This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
dtypes ( \
Map[str, Union[str, pandas.Series.dtype]] \
):
@@ -233,7 +239,7 @@ def to_dataframe(self, read_session, dtypes=None):
if pandas is None:
raise ImportError(_PANDAS_REQUIRED)

return self.rows(read_session).to_dataframe(dtypes=dtypes)
return self.rows(read_session=read_session).to_dataframe(dtypes=dtypes)


class ReadRowsIterable(object):
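With the hunks above, `rows()`, `to_arrow()`, and `to_dataframe()` no longer require the read session, because the schema arrives in the first `ReadRowsResponse`. A hedged usage sketch of the new calling convention; the client setup, project, and table paths below are illustrative and not taken from this diff:

```python
from google.cloud import bigquery_storage_v1

client = bigquery_storage_v1.BigQueryReadClient()
session = client.create_read_session(
    parent="projects/my-project",
    read_session={
        "table": "projects/my-project/datasets/my_dataset/tables/my_table",
        "data_format": bigquery_storage_v1.types.DataFormat.AVRO,
    },
    max_stream_count=1,
)
reader = client.read_rows(session.streams[0].name)

# New form: the schema is taken from the first ReadRowsResponse message,
# so no read_session argument is needed.
for row in reader.rows():
    print(row)

# The old form still works, but the argument is now deprecated.
rows_with_session = reader.rows(session)
```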
@@ -242,18 +248,25 @@ class ReadRowsIterable(object):
Args:
reader (google.cloud.bigquery_storage_v1.reader.ReadRowsStream):
A read rows stream.
read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
read_session ( \
Optional[~google.cloud.bigquery_storage_v1.types.ReadSession] \
):
DEPRECATED.

This argument was used to specify the schema of the rows in the
stream, but now the first message in a read stream contains
this information.
"""

# This class is modelled after the google.cloud.bigquery.table.RowIterator
# and aims to be API compatible where possible.

def __init__(self, reader, read_session):
def __init__(self, reader, read_session=None):
self._reader = reader
self._read_session = read_session
self._stream_parser = _StreamParser.from_read_session(self._read_session)
if read_session is not None:
self._stream_parser = _StreamParser.from_read_session(read_session)
else:
self._stream_parser = None

@property
def pages(self):
@@ -266,6 +279,10 @@ def pages(self):
# Each page is an iterator of rows. But also has num_items, remaining,
# and to_dataframe.
for message in self._reader:
# Only the first message contains the schema, which is needed to
# decode the messages.
if not self._stream_parser:
self._stream_parser = _StreamParser.from_read_rows_response(message)
yield ReadRowsPage(self._stream_parser, message)

def __iter__(self):
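The page-level API is unchanged by this hunk; only the parser construction moves inside the loop, keyed off the first message. A brief sketch, reusing the hypothetical `reader` from the earlier example:

```python
rows_iterable = reader.rows()  # no read_session argument
for page in rows_iterable.pages:
    # The stream parser is built from the first ReadRowsResponse, then reused.
    frame = page.to_dataframe()
    print(page.num_items, len(frame))
```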
@@ -328,10 +345,11 @@ def to_dataframe(self, dtypes=None):
# pandas dataframe is about 2x faster. This is because pandas.concat is
# rarely no-copy, whereas pyarrow.Table.from_batches + to_pandas is
# usually no-copy.
schema_type = self._read_session._pb.WhichOneof("schema")

if schema_type == "arrow_schema":
try:
record_batch = self.to_arrow()
except NotImplementedError:
pass
else:
df = record_batch.to_pandas()
for column in dtypes:
df[column] = pandas.Series(df[column], dtype=dtypes[column])
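The hunk above replaces the explicit `WhichOneof("schema")` check with an EAFP pattern: `to_arrow()` raises `NotImplementedError` for Avro streams, and the reader falls back to the row-based path. A standalone sketch of the same strategy, not the library's exact code:

```python
import pandas

def frame_from_rows(rows_iterable, dtypes=None):
    """Build a DataFrame, preferring the Arrow path when it is available."""
    dtypes = dtypes or {}
    try:
        # Fast path: Arrow streams convert in bulk, usually without copies.
        frame = rows_iterable.to_arrow().to_pandas()
    except NotImplementedError:
        # Slow path: Avro streams build one frame per page and concatenate.
        frame = pandas.concat(
            [page.to_dataframe(dtypes=dtypes) for page in rows_iterable.pages]
        )
    for column, dtype in dtypes.items():
        frame[column] = pandas.Series(frame[column], dtype=dtype)
    return frame
```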
@@ -491,6 +509,12 @@ def to_dataframe(self, message, dtypes=None):
def to_rows(self, message):
raise NotImplementedError("Not implemented.")

def _parse_avro_schema(self):
raise NotImplementedError("Not implemented.")

def _parse_arrow_schema(self):
raise NotImplementedError("Not implemented.")

@staticmethod
def from_read_session(read_session):
schema_type = read_session._pb.WhichOneof("schema")
@@ -503,22 +527,38 @@ def from_read_session(read_session):
"Unsupported schema type in read_session: {0}".format(schema_type)
)

@staticmethod
def from_read_rows_response(message):
schema_type = message._pb.WhichOneof("schema")
if schema_type == "avro_schema":
return _AvroStreamParser(message)
elif schema_type == "arrow_schema":
return _ArrowStreamParser(message)
else:
raise TypeError(
"Unsupported schema type in message: {0}".format(schema_type)
)


class _AvroStreamParser(_StreamParser):
"""Helper to parse Avro messages into useful representations."""

def __init__(self, read_session):
def __init__(self, message):
"""Construct an _AvroStreamParser.

Args:
read_session (google.cloud.bigquery_storage_v1.types.ReadSession):
A read session. This is required because it contains the schema
used in the stream messages.
message (Union[
google.cloud.bigquery_storage_v1.types.ReadSession, \
google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
]):
Either the first message of data from a read rows stream or a
read session. Both types contain a oneof "schema" field, which
can be used to determine how to deserialize rows.
"""
if fastavro is None:
raise ImportError(_FASTAVRO_REQUIRED)

self._read_session = read_session
self._first_message = message
self._avro_schema_json = None
self._fastavro_schema = None
self._column_names = None
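A hedged illustration of the new `from_read_rows_response` factory: the `schema` oneof on the first response decides which parser is built. Constructing the response from a dict is an assumption about proto-plus message construction, and the bytes payload is a dummy value:

```python
from google.cloud.bigquery_storage_v1 import reader, types

# A first message whose "schema" oneof is arrow_schema (payload is a dummy).
response = types.ReadRowsResponse(arrow_schema={"serialized_schema": b"\x00"})

parser = reader._StreamParser.from_read_rows_response(response)
# An _ArrowStreamParser is returned because arrow_schema is the set oneof.
print(type(parser).__name__)
```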
@@ -548,6 +588,10 @@ def to_dataframe(self, message, dtypes=None):
strings in the fastavro library.

Args:
message ( \
~google.cloud.bigquery_storage_v1.types.ReadRowsResponse \
):
A message containing Avro bytes to parse into a pandas DataFrame.
dtypes ( \
Map[str, Union[str, pandas.Series.dtype]] \
):
@@ -578,10 +622,11 @@ def _parse_avro_schema(self):
if self._avro_schema_json:
return

self._avro_schema_json = json.loads(self._read_session.avro_schema.schema)
self._avro_schema_json = json.loads(self._first_message.avro_schema.schema)
self._column_names = tuple(
(field["name"] for field in self._avro_schema_json["fields"])
)
self._first_message = None

def _parse_fastavro(self):
"""Convert parsed Avro schema to fastavro format."""
@@ -615,11 +660,22 @@ def to_rows(self, message):


class _ArrowStreamParser(_StreamParser):
def __init__(self, read_session):
def __init__(self, message):
"""Construct an _ArrowStreamParser.

Args:
message (Union[
google.cloud.bigquery_storage_v1.types.ReadSession, \
google.cloud.bigquery_storage_v1.types.ReadRowsResponse, \
]):
Either the first message of data from a read rows stream or a
read session. Both types contain a oneof "schema" field, which
can be used to determine how to deserialize rows.
"""
if pyarrow is None:
raise ImportError(_PYARROW_REQUIRED)

self._read_session = read_session
self._first_message = message
self._schema = None

def to_arrow(self, message):
@@ -659,6 +715,7 @@ def _parse_arrow_schema(self):
return

self._schema = pyarrow.ipc.read_schema(
pyarrow.py_buffer(self._read_session.arrow_schema.serialized_schema)
pyarrow.py_buffer(self._first_message.arrow_schema.serialized_schema)
)
self._column_names = [field.name for field in self._schema]
self._first_message = None
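Taken together, the reader.py changes let every convenience method run without the session object. A short hedged sketch, reusing the hypothetical `reader` from the first example; the dtypes mapping and column name are illustrative:

```python
# New form: the schema comes from the first response message.
df = reader.to_dataframe(dtypes={"int64_field": "float64"})

# Old form (deprecated but still accepted) would have been:
#   df = reader.to_dataframe(session, dtypes={"int64_field": "float64"})
```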
77 changes: 40 additions & 37 deletions tests/system/conftest.py
@@ -18,13 +18,41 @@
import os
import uuid

import google.auth
from google.cloud import bigquery
import pytest
import test_utils.prefixer

from . import helpers


prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")


_TABLE_FORMAT = "projects/{}/datasets/{}/tables/{}"
_ASSETS_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "assets")
_ALL_TYPES_SCHEMA = [
bigquery.SchemaField("string_field", "STRING"),
bigquery.SchemaField("bytes_field", "BYTES"),
bigquery.SchemaField("int64_field", "INT64"),
bigquery.SchemaField("float64_field", "FLOAT64"),
bigquery.SchemaField("numeric_field", "NUMERIC"),
bigquery.SchemaField("bool_field", "BOOL"),
bigquery.SchemaField("geography_field", "GEOGRAPHY"),
bigquery.SchemaField(
"person_struct_field",
"STRUCT",
fields=(
bigquery.SchemaField("name", "STRING"),
bigquery.SchemaField("age", "INT64"),
),
),
bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
bigquery.SchemaField("date_field", "DATE"),
bigquery.SchemaField("time_field", "TIME"),
bigquery.SchemaField("datetime_field", "DATETIME"),
bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
]


@pytest.fixture(scope="session")
@@ -38,18 +66,9 @@ def use_mtls():


@pytest.fixture(scope="session")
def credentials(use_mtls):
import google.auth
from google.oauth2 import service_account

if use_mtls:
# mTLS test uses user credentials instead of service account credentials
creds, _ = google.auth.default()
return creds

# NOTE: the test config in noxfile checks that the env variable is indeed set
filename = os.environ["GOOGLE_APPLICATION_CREDENTIALS"]
return service_account.Credentials.from_service_account_file(filename)
def credentials():
creds, _ = google.auth.default()
return creds


@pytest.fixture()
Expand Down Expand Up @@ -77,8 +96,7 @@ def local_shakespeare_table_reference(project_id, use_mtls):
def dataset(project_id, bq_client):
from google.cloud import bigquery

unique_suffix = str(uuid.uuid4()).replace("-", "_")
dataset_name = "bq_storage_system_tests_" + unique_suffix
dataset_name = prefixer.create_prefix()

dataset_id = "{}.{}".format(project_id, dataset_name)
dataset = bigquery.Dataset(dataset_id)
@@ -120,35 +138,20 @@ def bq_client(credentials, use_mtls):
return bigquery.Client(credentials=credentials)


@pytest.fixture(scope="session", autouse=True)
def cleanup_datasets(bq_client: bigquery.Client):
for dataset in bq_client.list_datasets():
if prefixer.should_cleanup(dataset.dataset_id):
bq_client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)


@pytest.fixture
def all_types_table_ref(project_id, dataset, bq_client):
from google.cloud import bigquery

schema = [
bigquery.SchemaField("string_field", "STRING"),
bigquery.SchemaField("bytes_field", "BYTES"),
bigquery.SchemaField("int64_field", "INT64"),
bigquery.SchemaField("float64_field", "FLOAT64"),
bigquery.SchemaField("numeric_field", "NUMERIC"),
bigquery.SchemaField("bool_field", "BOOL"),
bigquery.SchemaField("geography_field", "GEOGRAPHY"),
bigquery.SchemaField(
"person_struct_field",
"STRUCT",
fields=(
bigquery.SchemaField("name", "STRING"),
bigquery.SchemaField("age", "INT64"),
),
),
bigquery.SchemaField("timestamp_field", "TIMESTAMP"),
bigquery.SchemaField("date_field", "DATE"),
bigquery.SchemaField("time_field", "TIME"),
bigquery.SchemaField("datetime_field", "DATETIME"),
bigquery.SchemaField("string_array_field", "STRING", mode="REPEATED"),
]
bq_table = bigquery.table.Table(
table_ref="{}.{}.complex_records".format(project_id, dataset.dataset_id),
schema=schema,
schema=_ALL_TYPES_SCHEMA,
)

created_table = bq_client.create_table(bq_table)
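The conftest changes replace the ad-hoc uuid-based dataset names with the shared `test_utils.prefixer.Prefixer`, which also drives the new autouse `cleanup_datasets` fixture. A hedged sketch of the pattern outside pytest, mirroring the fixtures above; the exact prefix format and cleanup cutoff are implementation details of test_utils, not shown in this diff:

```python
import test_utils.prefixer
from google.cloud import bigquery

prefixer = test_utils.prefixer.Prefixer("python-bigquery-storage", "tests/system")
client = bigquery.Client()

# New datasets get a recognizable, timestamped prefix...
dataset_id = "{}.{}".format(client.project, prefixer.create_prefix())
client.create_dataset(bigquery.Dataset(dataset_id))

# ...and a session-scoped sweep deletes leftovers that earlier, interrupted
# runs of this suite failed to clean up, based on that prefix.
for dataset in client.list_datasets():
    if prefixer.should_cleanup(dataset.dataset_id):
        client.delete_dataset(dataset, delete_contents=True, not_found_ok=True)
```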