perf: improve streaming performance (#240)
* perf: improve streaming performance by using raw pbs

* refactor: remove unused import

Co-authored-by: larkee <larkee@users.noreply.github.com>
larkee and larkee committed Feb 23, 2021
1 parent 434967e commit 3e35d4a
Showing 5 changed files with 191 additions and 365 deletions.
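The gist of the change: proto-plus wrappers re-wrap the underlying protobuf on every attribute access, so the streaming path was paying that cost for every cell. Unwrapping each response to its raw protobuf once, then reading plain pb fields, avoids it. A minimal sketch of the idea (illustrative, not part of the diff; the one-cell result set is a contrived assumption):

from google.cloud.spanner_v1 import PartialResultSet  # proto-plus wrapper
from google.protobuf.struct_pb2 import Value

# Spanner encodes INT64 cells as decimal strings inside struct_pb2.Value.
response = PartialResultSet(values=[Value(string_value="42")])

# Wrapped access: attribute lookups go through the proto-plus wrapper,
# which is costly inside a per-cell loop.
wrapped_cell = int(response.values[0].string_value)

# Raw access: unwrap once with proto-plus's .pb(), then read plain pb fields.
raw = PartialResultSet.pb(response)
raw_cell = int(raw.values[0].string_value)

assert wrapped_cell == raw_cell == 42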
77 changes: 33 additions & 44 deletions google/cloud/spanner_v1/_helpers.py
@@ -161,41 +161,6 @@ def _make_list_value_pbs(values):


# pylint: disable=too-many-branches
-def _parse_value(value, field_type):
-    if value is None:
-        return None
-    if field_type.code == TypeCode.STRING:
-        result = value
-    elif field_type.code == TypeCode.BYTES:
-        result = value.encode("utf8")
-    elif field_type.code == TypeCode.BOOL:
-        result = value
-    elif field_type.code == TypeCode.INT64:
-        result = int(value)
-    elif field_type.code == TypeCode.FLOAT64:
-        if isinstance(value, str):
-            result = float(value)
-        else:
-            result = value
-    elif field_type.code == TypeCode.DATE:
-        result = _date_from_iso8601_date(value)
-    elif field_type.code == TypeCode.TIMESTAMP:
-        DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
-        result = DatetimeWithNanoseconds.from_rfc3339(value)
-    elif field_type.code == TypeCode.ARRAY:
-        result = [_parse_value(item, field_type.array_element_type) for item in value]
-    elif field_type.code == TypeCode.STRUCT:
-        result = [
-            _parse_value(item, field_type.struct_type.fields[i].type_)
-            for (i, item) in enumerate(value)
-        ]
-    elif field_type.code == TypeCode.NUMERIC:
-        result = decimal.Decimal(value)
-    else:
-        raise ValueError("Unknown type: %s" % (field_type,))
-    return result


def _parse_value_pb(value_pb, field_type):
    """Convert a Value protobuf to cell data.
@@ -209,17 +174,41 @@ def _parse_value_pb(value_pb, field_type):
    :returns: value extracted from value_pb
    :raises ValueError: if unknown type is passed
    """
+    type_code = field_type.code
    if value_pb.HasField("null_value"):
        return None
-    if value_pb.HasField("string_value"):
-        return _parse_value(value_pb.string_value, field_type)
-    if value_pb.HasField("bool_value"):
-        return _parse_value(value_pb.bool_value, field_type)
-    if value_pb.HasField("number_value"):
-        return _parse_value(value_pb.number_value, field_type)
-    if value_pb.HasField("list_value"):
-        return _parse_value(value_pb.list_value, field_type)
-    raise ValueError("No value set in Value: %s" % (value_pb,))
+    if type_code == TypeCode.STRING:
+        return value_pb.string_value
+    elif type_code == TypeCode.BYTES:
+        return value_pb.string_value.encode("utf8")
+    elif type_code == TypeCode.BOOL:
+        return value_pb.bool_value
+    elif type_code == TypeCode.INT64:
+        return int(value_pb.string_value)
+    elif type_code == TypeCode.FLOAT64:
+        if value_pb.HasField("string_value"):
+            return float(value_pb.string_value)
+        else:
+            return value_pb.number_value
+    elif type_code == TypeCode.DATE:
+        return _date_from_iso8601_date(value_pb.string_value)
+    elif type_code == TypeCode.TIMESTAMP:
+        DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds
+        return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value)
+    elif type_code == TypeCode.ARRAY:
+        return [
+            _parse_value_pb(item_pb, field_type.array_element_type)
+            for item_pb in value_pb.list_value.values
+        ]
+    elif type_code == TypeCode.STRUCT:
+        return [
+            _parse_value_pb(item_pb, field_type.struct_type.fields[i].type_)
+            for (i, item_pb) in enumerate(value_pb.list_value.values)
+        ]
+    elif type_code == TypeCode.NUMERIC:
+        return decimal.Decimal(value_pb.string_value)
+    else:
+        raise ValueError("Unknown type: %s" % (field_type,))


# pylint: enable=too-many-branches
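With the rewrite, _parse_value_pb dispatches on the column's TypeCode and reads the matching oneof field straight off the raw Value pb. A quick illustrative check of that dispatch (assumed usage, not taken from the commit or its tests):

from google.protobuf.struct_pb2 import Value

from google.cloud.spanner_v1 import Type, TypeCode
from google.cloud.spanner_v1._helpers import _parse_value_pb

# INT64 arrives as a decimal string; BOOL as a native bool; NULL as null_value.
assert _parse_value_pb(Value(string_value="7"), Type(code=TypeCode.INT64)) == 7
assert _parse_value_pb(Value(bool_value=True), Type(code=TypeCode.BOOL)) is True
assert _parse_value_pb(Value(null_value=0), Type(code=TypeCode.STRING)) is None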
73 changes: 41 additions & 32 deletions google/cloud/spanner_v1/streamed.py
@@ -14,12 +14,15 @@

"""Wrapper for streaming results."""

+from google.protobuf.struct_pb2 import ListValue
+from google.protobuf.struct_pb2 import Value
from google.cloud import exceptions
+from google.cloud.spanner_v1 import PartialResultSet
from google.cloud.spanner_v1 import TypeCode
import six

# pylint: disable=ungrouped-imports
-from google.cloud.spanner_v1._helpers import _parse_value
+from google.cloud.spanner_v1._helpers import _parse_value_pb

# pylint: enable=ungrouped-imports

@@ -88,29 +91,33 @@ def _merge_chunk(self, value):
        field = self.fields[current_column]
        merged = _merge_by_type(self._pending_chunk, value, field.type_)
        self._pending_chunk = None
-        return _parse_value(merged, field.type_)
+        return merged

    def _merge_values(self, values):
        """Merge values into rows.

        :type values: list of :class:`~google.protobuf.struct_pb2.Value`
        :param values: non-chunked values from partial result set.
        """
-        width = len(self.fields)
+        field_types = [field.type_ for field in self.fields]
+        width = len(field_types)
+        index = len(self._current_row)
        for value in values:
-            index = len(self._current_row)
-            field = self.fields[index]
-            self._current_row.append(_parse_value(value, field.type_))
-            if len(self._current_row) == width:
+            self._current_row.append(_parse_value_pb(value, field_types[index]))
+            index += 1
+            if index == width:
                self._rows.append(self._current_row)
                self._current_row = []
+                index = 0

    def _consume_next(self):
        """Consume the next partial result set from the stream.

        Parse the result set into new/existing rows in :attr:`_rows`
        """
        response = six.next(self._response_iterator)
+        response_pb = PartialResultSet.pb(response)

        if self._metadata is None:  # first response
            metadata = self._metadata = response.metadata
@@ -119,29 +126,27 @@ def _consume_next(self):
            source = self._source
            if source is not None and source._transaction_id is None:
                source._transaction_id = metadata.transaction.id

-        if "stats" in response:  # last response
+        if response_pb.HasField("stats"):  # last response
            self._stats = response.stats

-        values = list(response.values)
+        values = list(response_pb.values)
        if self._pending_chunk is not None:
            values[0] = self._merge_chunk(values[0])

-        if response.chunked_value:
+        if response_pb.chunked_value:
            self._pending_chunk = values.pop()

        self._merge_values(values)

    def __iter__(self):
-        iter_rows, self._rows[:] = self._rows[:], ()
        while True:
-            if not iter_rows:
-                try:
-                    self._consume_next()
-                except StopIteration:
-                    return
-                iter_rows, self._rows[:] = self._rows[:], ()
+            iter_rows, self._rows[:] = self._rows[:], ()
            while iter_rows:
                yield iter_rows.pop(0)
+            try:
+                self._consume_next()
+            except StopIteration:
+                return

    def one(self):
        """Return exactly one result, or raise an exception.
@@ -213,17 +218,23 @@ def _unmergeable(lhs, rhs, type_):

def _merge_float64(lhs, rhs, type_):  # pylint: disable=unused-argument
    """Helper for '_merge_by_type'."""
-    if type(lhs) == str:
-        return float(lhs + rhs)
-    array_continuation = type(lhs) == float and type(rhs) == str and rhs == ""
+    lhs_kind = lhs.WhichOneof("kind")
+    if lhs_kind == "string_value":
+        return Value(string_value=lhs.string_value + rhs.string_value)
+    rhs_kind = rhs.WhichOneof("kind")
+    array_continuation = (
+        lhs_kind == "number_value"
+        and rhs_kind == "string_value"
+        and rhs.string_value == ""
+    )
    if array_continuation:
        return lhs
    raise Unmergeable(lhs, rhs, type_)


def _merge_string(lhs, rhs, type_):  # pylint: disable=unused-argument
    """Helper for '_merge_by_type'."""
-    return str(lhs) + str(rhs)
+    return Value(string_value=lhs.string_value + rhs.string_value)


_UNMERGEABLE_TYPES = (TypeCode.BOOL,)
@@ -234,17 +245,17 @@ def _merge_array(lhs, rhs, type_):
    element_type = type_.array_element_type
    if element_type.code in _UNMERGEABLE_TYPES:
        # Individual values cannot be merged, just concatenate
-        lhs.extend(rhs)
+        lhs.list_value.values.extend(rhs.list_value.values)
        return lhs
+    lhs, rhs = list(lhs.list_value.values), list(rhs.list_value.values)

    # Sanity check: If either list is empty, short-circuit.
    # This is effectively a no-op.
    if not len(lhs) or not len(rhs):
-        lhs.extend(rhs)
-        return lhs
+        return Value(list_value=ListValue(values=(lhs + rhs)))

    first = rhs.pop(0)
-    if first is None:  # can't merge
+    if first.HasField("null_value"):  # can't merge
        lhs.append(first)
    else:
        last = lhs.pop()
@@ -255,23 +266,22 @@ def _merge_struct(lhs, rhs, type_):
            lhs.append(first)
        else:
            lhs.append(merged)
-    lhs.extend(rhs)
-    return lhs
+    return Value(list_value=ListValue(values=(lhs + rhs)))


def _merge_struct(lhs, rhs, type_):
    """Helper for '_merge_by_type'."""
    fields = type_.struct_type.fields
+    lhs, rhs = list(lhs.list_value.values), list(rhs.list_value.values)

    # Sanity check: If either list is empty, short-circuit.
    # This is effectively a no-op.
    if not len(lhs) or not len(rhs):
-        lhs.extend(rhs)
-        return lhs
+        return Value(list_value=ListValue(values=(lhs + rhs)))

    candidate_type = fields[len(lhs) - 1].type_
    first = rhs.pop(0)
-    if first is None or candidate_type.code in _UNMERGEABLE_TYPES:
+    if first.HasField("null_value") or candidate_type.code in _UNMERGEABLE_TYPES:
        lhs.append(first)
    else:
        last = lhs.pop()
@@ -282,8 +292,7 @@ def _merge_struct(lhs, rhs, type_):
            lhs.append(first)
        else:
            lhs.append(merged)
-    lhs.extend(rhs)
-    return lhs
+    return Value(list_value=ListValue(values=lhs + rhs))


_MERGE_BY_TYPE = {
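Because chunks now stay as raw Value pbs until a row completes, the merge helpers concatenate pb fields and wrap the result back into a Value. A small sketch of the string case (illustrative; the chunk contents are made up):

from google.protobuf.struct_pb2 import Value

from google.cloud.spanner_v1 import Type, TypeCode
from google.cloud.spanner_v1.streamed import _merge_string

# A STRING cell split across two PartialResultSets is merged pb-to-pb:
merged = _merge_string(
    Value(string_value="hello, w"),
    Value(string_value="orld"),
    Type(code=TypeCode.STRING),
)
assert merged.string_value == "hello, world"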
