From ddc9f7106eab91d4adea2db65e69e3a870a7cd46 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Mon, 23 Mar 2020 08:16:05 -0700 Subject: [PATCH] fix: make TablesClient.predict permissive to input data types (#13) * fix: make TablesClient.predict permissive to input data types The current implementation checks input instance's data type according to column spec's data type. E.g., if the column spec is float, it requires the input to be float or int, but not string. However, this is not the same as tables API contract: float column data type could be string or number values. The current code raises exception with error messages like TypeError: '0' has type str, but expected one of: int, long, float when passed in a string value for numeric columns, which should be allowed. This PR changes the logic so that Python SDK side will be permissive for the input data type - basically all JSON compatible data types are allow. And rely on backend for the validation. * Fix according to comment. * Fix lint. * Address comment: use elif instead of if Co-authored-by: Helin Wang --- .../automl_v1beta1/tables/tables_client.py | 98 ++++++++++++------- 1 file changed, 60 insertions(+), 38 deletions(-) diff --git a/google/cloud/automl_v1beta1/tables/tables_client.py b/google/cloud/automl_v1beta1/tables/tables_client.py index 32137db2..378eca61 100644 --- a/google/cloud/automl_v1beta1/tables/tables_client.py +++ b/google/cloud/automl_v1beta1/tables/tables_client.py @@ -18,11 +18,12 @@ import pkg_resources import logging +import six from google.api_core.gapic_v1 import client_info from google.api_core import exceptions from google.cloud.automl_v1beta1 import gapic -from google.cloud.automl_v1beta1.proto import data_types_pb2, data_items_pb2 +from google.cloud.automl_v1beta1.proto import data_items_pb2 from google.cloud.automl_v1beta1.tables import gcs_client from google.protobuf import struct_pb2 @@ -31,6 +32,61 @@ _LOGGER = logging.getLogger(__name__) +def to_proto_value(value): + """translates a Python value to a google.protobuf.Value. + + Args: + value: The Python value to be translated. + + Returns: + Tuple of the translated google.protobuf.Value and error if any. + """ + # possible Python types (this is a Python3 module): + # https://simplejson.readthedocs.io/en/latest/#encoders-and-decoders + # JSON Python 2 Python 3 + # object dict dict + # array list list + # string unicode str + # number (int) int, long int + # number (real) float float + # true True True + # false False False + # null None None + if value is None: + # translate null to an empty value. + return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), None + elif isinstance(value, bool): + # This check needs to happen before isinstance(value, int), + # isinstance(value, int) returns True when value is bool. + return struct_pb2.Value(bool_value=value), None + elif isinstance(value, six.integer_types) or isinstance(value, float): + return struct_pb2.Value(number_value=value), None + elif isinstance(value, six.string_types) or isinstance(value, six.text_type): + return struct_pb2.Value(string_value=value), None + elif isinstance(value, dict): + struct_value = struct_pb2.Struct() + for key, v in value.items(): + field_value, err = to_proto_value(v) + if err is not None: + return None, err + + struct_value.fields[key].CopyFrom(field_value) + return struct_pb2.Value(struct_value=struct_value), None + elif isinstance(value, list): + list_value = [] + for v in value: + proto_value, err = to_proto_value(v) + if err is not None: + return None, err + list_value.append(proto_value) + return ( + struct_pb2.Value(list_value=struct_pb2.ListValue(values=list_value)), + None, + ) + else: + return None, "unsupport data type: {}".format(type(value)) + + class TablesClient(object): """ AutoML Tables API helper. @@ -404,42 +460,6 @@ def __column_spec_name_from_args( return column_spec_name - def __data_type_to_proto_value(self, data_type, value): - type_code = data_type.type_code - if value is None: - return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) - elif type_code == data_types_pb2.FLOAT64: - return struct_pb2.Value(number_value=value) - elif ( - type_code == data_types_pb2.TIMESTAMP - or type_code == data_types_pb2.STRING - or type_code == data_types_pb2.CATEGORY - ): - return struct_pb2.Value(string_value=value) - elif type_code == data_types_pb2.ARRAY: - if isinstance(value, struct_pb2.ListValue): - # in case the user passed in a ListValue. - return struct_pb2.Value(list_value=value) - array = [] - for item in value: - array.append( - self.__data_type_to_proto_value(data_type.list_element_type, item) - ) - return struct_pb2.Value(list_value=struct_pb2.ListValue(values=array)) - elif type_code == data_types_pb2.STRUCT: - if isinstance(value, struct_pb2.Struct): - # in case the user passed in a Struct. - return struct_pb2.Value(struct_value=value) - struct_value = struct_pb2.Struct() - for k, v in value.items(): - field_value = self.__data_type_to_proto_value( - data_type.struct_type.fields[k], v - ) - struct_value.fields[k].CopyFrom(field_value) - return struct_pb2.Value(struct_value=struct_value) - else: - raise ValueError("Unknown type_code: {}".format(type_code)) - def __ensure_gcs_client_is_initialized(self, credentials, project): """Checks if GCS client is initialized. Initializes it if not. @@ -2714,7 +2734,9 @@ def predict( values = [] for i, c in zip(inputs, column_specs): - value_type = self.__data_type_to_proto_value(c.data_type, i) + value_type, err = to_proto_value(i) + if err is not None: + raise ValueError(err) values.append(value_type) row = data_items_pb2.Row(values=values)