diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 79a387eac66..0f56431cb37 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -161,41 +161,6 @@ def _make_list_value_pbs(values): # pylint: disable=too-many-branches -def _parse_value(value, field_type): - if value is None: - return None - if field_type.code == TypeCode.STRING: - result = value - elif field_type.code == TypeCode.BYTES: - result = value.encode("utf8") - elif field_type.code == TypeCode.BOOL: - result = value - elif field_type.code == TypeCode.INT64: - result = int(value) - elif field_type.code == TypeCode.FLOAT64: - if isinstance(value, str): - result = float(value) - else: - result = value - elif field_type.code == TypeCode.DATE: - result = _date_from_iso8601_date(value) - elif field_type.code == TypeCode.TIMESTAMP: - DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds - result = DatetimeWithNanoseconds.from_rfc3339(value) - elif field_type.code == TypeCode.ARRAY: - result = [_parse_value(item, field_type.array_element_type) for item in value] - elif field_type.code == TypeCode.STRUCT: - result = [ - _parse_value(item, field_type.struct_type.fields[i].type_) - for (i, item) in enumerate(value) - ] - elif field_type.code == TypeCode.NUMERIC: - result = decimal.Decimal(value) - else: - raise ValueError("Unknown type: %s" % (field_type,)) - return result - - def _parse_value_pb(value_pb, field_type): """Convert a Value protobuf to cell data. 
@@ -209,17 +174,41 @@ def _parse_value_pb(value_pb, field_type): :returns: value extracted from value_pb :raises ValueError: if unknown type is passed """ + type_code = field_type.code if value_pb.HasField("null_value"): return None - if value_pb.HasField("string_value"): - return _parse_value(value_pb.string_value, field_type) - if value_pb.HasField("bool_value"): - return _parse_value(value_pb.bool_value, field_type) - if value_pb.HasField("number_value"): - return _parse_value(value_pb.number_value, field_type) - if value_pb.HasField("list_value"): - return _parse_value(value_pb.list_value, field_type) - raise ValueError("No value set in Value: %s" % (value_pb,)) + if type_code == TypeCode.STRING: + return value_pb.string_value + elif type_code == TypeCode.BYTES: + return value_pb.string_value.encode("utf8") + elif type_code == TypeCode.BOOL: + return value_pb.bool_value + elif type_code == TypeCode.INT64: + return int(value_pb.string_value) + elif type_code == TypeCode.FLOAT64: + if value_pb.HasField("string_value"): + return float(value_pb.string_value) + else: + return value_pb.number_value + elif type_code == TypeCode.DATE: + return _date_from_iso8601_date(value_pb.string_value) + elif type_code == TypeCode.TIMESTAMP: + DatetimeWithNanoseconds = datetime_helpers.DatetimeWithNanoseconds + return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value) + elif type_code == TypeCode.ARRAY: + return [ + _parse_value_pb(item_pb, field_type.array_element_type) + for item_pb in value_pb.list_value.values + ] + elif type_code == TypeCode.STRUCT: + return [ + _parse_value_pb(item_pb, field_type.struct_type.fields[i].type_) + for (i, item_pb) in enumerate(value_pb.list_value.values) + ] + elif type_code == TypeCode.NUMERIC: + return decimal.Decimal(value_pb.string_value) + else: + raise ValueError("Unknown type: %s" % (field_type,)) # pylint: enable=too-many-branches diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index 
a8b15a8f2bd..5a752e01c7a 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -14,12 +14,15 @@ """Wrapper for streaming results.""" +from google.protobuf.struct_pb2 import ListValue +from google.protobuf.struct_pb2 import Value from google.cloud import exceptions +from google.cloud.spanner_v1 import PartialResultSet from google.cloud.spanner_v1 import TypeCode import six # pylint: disable=ungrouped-imports -from google.cloud.spanner_v1._helpers import _parse_value +from google.cloud.spanner_v1._helpers import _parse_value_pb # pylint: enable=ungrouped-imports @@ -88,7 +91,7 @@ def _merge_chunk(self, value): field = self.fields[current_column] merged = _merge_by_type(self._pending_chunk, value, field.type_) self._pending_chunk = None - return _parse_value(merged, field.type_) + return merged def _merge_values(self, values): """Merge values into rows. @@ -96,14 +99,17 @@ def _merge_values(self, values): :type values: list of :class:`~google.protobuf.struct_pb2.Value` :param values: non-chunked values from partial result set. """ - width = len(self.fields) + # Cache field types once so each value lookup avoids repeated attribute access. + field_types = [field.type_ for field in self.fields] + width = len(field_types) + index = len(self._current_row) for value in values: - index = len(self._current_row) - field = self.fields[index] - self._current_row.append(_parse_value(value, field.type_)) - if len(self._current_row) == width: + self._current_row.append(_parse_value_pb(value, field_types[index])) + index += 1 + if index == width: self._rows.append(self._current_row) self._current_row = [] + index = 0 def _consume_next(self): """Consume the next partial result set from the stream. 
@@ -111,6 +117,7 @@ def _consume_next(self): Parse the result set into new/existing rows in :attr:`_rows` """ response = six.next(self._response_iterator) + response = PartialResultSet.pb(response) if self._metadata is None: # first response metadata = self._metadata = response.metadata @@ -119,7 +126,7 @@ def _consume_next(self): if source is not None and source._transaction_id is None: source._transaction_id = metadata.transaction.id - if "stats" in response: # last response + if response.HasField("stats"): # last response self._stats = response.stats values = list(response.values) @@ -132,16 +139,14 @@ def _consume_next(self): self._merge_values(values) def __iter__(self): - iter_rows, self._rows[:] = self._rows[:], () while True: - if not iter_rows: - try: - self._consume_next() - except StopIteration: - return - iter_rows, self._rows[:] = self._rows[:], () + iter_rows, self._rows[:] = self._rows[:], () while iter_rows: yield iter_rows.pop(0) + try: + self._consume_next() + except StopIteration: + return def one(self): """Return exactly one result, or raise an exception. 
@@ -213,9 +218,15 @@ def _unmergeable(lhs, rhs, type_): def _merge_float64(lhs, rhs, type_): # pylint: disable=unused-argument """Helper for '_merge_by_type'.""" - if type(lhs) == str: - return float(lhs + rhs) - array_continuation = type(lhs) == float and type(rhs) == str and rhs == "" + lhs_kind = lhs.WhichOneof("kind") + if lhs_kind == "string_value": + return Value(string_value=lhs.string_value + rhs.string_value) + rhs_kind = rhs.WhichOneof("kind") + array_continuation = ( + lhs_kind == "number_value" + and rhs_kind == "string_value" + and rhs.string_value == "" + ) if array_continuation: return lhs raise Unmergeable(lhs, rhs, type_) @@ -223,7 +234,7 @@ def _merge_float64(lhs, rhs, type_): # pylint: disable=unused-argument def _merge_string(lhs, rhs, type_): # pylint: disable=unused-argument """Helper for '_merge_by_type'.""" - return str(lhs) + str(rhs) + return Value(string_value=lhs.string_value + rhs.string_value) _UNMERGEABLE_TYPES = (TypeCode.BOOL,) @@ -234,17 +245,17 @@ def _merge_array(lhs, rhs, type_): element_type = type_.array_element_type if element_type.code in _UNMERGEABLE_TYPES: # Individual values cannot be merged, just concatenate - lhs.extend(rhs) + lhs.list_value.values.extend(rhs.list_value.values) return lhs + lhs, rhs = list(lhs.list_value.values), list(rhs.list_value.values) # Sanity check: If either list is empty, short-circuit. # This is effectively a no-op. 
if not len(lhs) or not len(rhs): - lhs.extend(rhs) - return lhs + return Value(list_value=ListValue(values=(lhs + rhs))) first = rhs.pop(0) - if first is None: # can't merge + if first.HasField("null_value"): # can't merge lhs.append(first) else: last = lhs.pop() @@ -255,23 +266,22 @@ def _merge_array(lhs, rhs, type_): lhs.append(first) else: lhs.append(merged) - lhs.extend(rhs) - return lhs + return Value(list_value=ListValue(values=(lhs + rhs))) def _merge_struct(lhs, rhs, type_): """Helper for '_merge_by_type'.""" fields = type_.struct_type.fields + lhs, rhs = list(lhs.list_value.values), list(rhs.list_value.values) # Sanity check: If either list is empty, short-circuit. # This is effectively a no-op. if not len(lhs) or not len(rhs): - lhs.extend(rhs) - return lhs + return Value(list_value=ListValue(values=(lhs + rhs))) candidate_type = fields[len(lhs) - 1].type_ first = rhs.pop(0) - if first is None or candidate_type.code in _UNMERGEABLE_TYPES: + if first.HasField("null_value") or candidate_type.code in _UNMERGEABLE_TYPES: lhs.append(first) else: last = lhs.pop() @@ -282,8 +292,7 @@ def _merge_struct(lhs, rhs, type_): lhs.append(first) else: lhs.append(merged) - lhs.extend(rhs) - return lhs + return Value(list_value=ListValue(values=lhs + rhs)) _MERGE_BY_TYPE = {