diff --git a/google/cloud/automl_v1beta1/tables/tables_client.py b/google/cloud/automl_v1beta1/tables/tables_client.py
index 32137db2..378eca61 100644
--- a/google/cloud/automl_v1beta1/tables/tables_client.py
+++ b/google/cloud/automl_v1beta1/tables/tables_client.py
@@ -18,11 +18,12 @@
 
 import pkg_resources
 import logging
+import six
 
 from google.api_core.gapic_v1 import client_info
 from google.api_core import exceptions
 from google.cloud.automl_v1beta1 import gapic
-from google.cloud.automl_v1beta1.proto import data_types_pb2, data_items_pb2
+from google.cloud.automl_v1beta1.proto import data_items_pb2
 from google.cloud.automl_v1beta1.tables import gcs_client
 from google.protobuf import struct_pb2
 
@@ -31,6 +32,78 @@
 _LOGGER = logging.getLogger(__name__)
 
 
+def to_proto_value(value):
+    """Translate a Python value to a google.protobuf.Value.
+
+    Failures are reported through the second element of the returned
+    tuple instead of being raised, so callers must check it first.
+
+    Args:
+        value: The Python value to be translated. Supported inputs are
+            None, bool, integers, float, strings, dict, list and tuple
+            (containers are translated recursively), plus ready-made
+            struct_pb2.Struct and struct_pb2.ListValue messages, which
+            are wrapped as-is.
+
+    Returns:
+        Tuple of the translated google.protobuf.Value and error if any.
+        Exactly one of the two elements is None.
+    """
+    # possible Python types (six keeps this working on Python 2 and 3):
+    # https://simplejson.readthedocs.io/en/latest/#encoders-and-decoders
+    #   JSON              Python 2        Python 3
+    #   object            dict            dict
+    #   array             list            list
+    #   string            unicode         str
+    #   number (int)      int, long       int
+    #   number (real)     float           float
+    #   true              True            True
+    #   false             False           False
+    #   null              None            None
+    if value is None:
+        # translate null to an empty value.
+        return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), None
+    elif isinstance(value, bool):
+        # This check needs to happen before isinstance(value, int),
+        # isinstance(value, int) returns True when value is bool.
+        return struct_pb2.Value(bool_value=value), None
+    elif isinstance(value, six.integer_types + (float,)):
+        return struct_pb2.Value(number_value=value), None
+    elif isinstance(value, six.string_types):
+        # six.string_types already covers six.text_type: it is basestring
+        # on Python 2 and str on Python 3.
+        return struct_pb2.Value(string_value=value), None
+    elif isinstance(value, struct_pb2.Struct):
+        # in case the user passed in a Struct.
+        return struct_pb2.Value(struct_value=value), None
+    elif isinstance(value, struct_pb2.ListValue):
+        # in case the user passed in a ListValue.
+        return struct_pb2.Value(list_value=value), None
+    elif isinstance(value, dict):
+        struct_value = struct_pb2.Struct()
+        for key, v in value.items():
+            field_value, err = to_proto_value(v)
+            if err is not None:
+                return None, err
+
+            struct_value.fields[key].CopyFrom(field_value)
+        return struct_pb2.Value(struct_value=struct_value), None
+    elif isinstance(value, (list, tuple)):
+        # tuples are translated the same way as lists.
+        list_value = []
+        for v in value:
+            proto_value, err = to_proto_value(v)
+            if err is not None:
+                return None, err
+            list_value.append(proto_value)
+        return (
+            struct_pb2.Value(list_value=struct_pb2.ListValue(values=list_value)),
+            None,
+        )
+    else:
+        return None, "unsupport data type: {}".format(type(value))
+
+
 class TablesClient(object):
     """
     AutoML Tables API helper.
@@ -404,42 +477,6 @@ def __column_spec_name_from_args(
 
         return column_spec_name
 
-    def __data_type_to_proto_value(self, data_type, value):
-        type_code = data_type.type_code
-        if value is None:
-            return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)
-        elif type_code == data_types_pb2.FLOAT64:
-            return struct_pb2.Value(number_value=value)
-        elif (
-            type_code == data_types_pb2.TIMESTAMP
-            or type_code == data_types_pb2.STRING
-            or type_code == data_types_pb2.CATEGORY
-        ):
-            return struct_pb2.Value(string_value=value)
-        elif type_code == data_types_pb2.ARRAY:
-            if isinstance(value, struct_pb2.ListValue):
-                # in case the user passed in a ListValue.
-                return struct_pb2.Value(list_value=value)
-            array = []
-            for item in value:
-                array.append(
-                    self.__data_type_to_proto_value(data_type.list_element_type, item)
-                )
-            return struct_pb2.Value(list_value=struct_pb2.ListValue(values=array))
-        elif type_code == data_types_pb2.STRUCT:
-            if isinstance(value, struct_pb2.Struct):
-                # in case the user passed in a Struct.
-                return struct_pb2.Value(struct_value=value)
-            struct_value = struct_pb2.Struct()
-            for k, v in value.items():
-                field_value = self.__data_type_to_proto_value(
-                    data_type.struct_type.fields[k], v
-                )
-                struct_value.fields[k].CopyFrom(field_value)
-            return struct_pb2.Value(struct_value=struct_value)
-        else:
-            raise ValueError("Unknown type_code: {}".format(type_code))
-
     def __ensure_gcs_client_is_initialized(self, credentials, project):
         """Checks if GCS client is initialized. Initializes it if not.
 
@@ -2714,7 +2751,9 @@ def predict(
 
         values = []
         for i, c in zip(inputs, column_specs):
-            value_type = self.__data_type_to_proto_value(c.data_type, i)
+            value_type, err = to_proto_value(i)
+            if err is not None:
+                raise ValueError(err)
             values.append(value_type)
 
         row = data_items_pb2.Row(values=values)