diff --git a/google/cloud/automl_v1beta1/tables/tables_client.py b/google/cloud/automl_v1beta1/tables/tables_client.py index fd10538c..32137db2 100644 --- a/google/cloud/automl_v1beta1/tables/tables_client.py +++ b/google/cloud/automl_v1beta1/tables/tables_client.py @@ -22,8 +22,10 @@ from google.api_core.gapic_v1 import client_info from google.api_core import exceptions from google.cloud.automl_v1beta1 import gapic -from google.cloud.automl_v1beta1.proto import data_types_pb2 +from google.cloud.automl_v1beta1.proto import data_types_pb2, data_items_pb2 from google.cloud.automl_v1beta1.tables import gcs_client +from google.protobuf import struct_pb2 + _GAPIC_LIBRARY_VERSION = pkg_resources.get_distribution("google-cloud-automl").version _LOGGER = logging.getLogger(__name__) @@ -402,21 +404,39 @@ def __column_spec_name_from_args( return column_spec_name - def __type_code_to_value_type(self, type_code, value): + def __data_type_to_proto_value(self, data_type, value): + type_code = data_type.type_code if value is None: - return {"null_value": 0} + return struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE) elif type_code == data_types_pb2.FLOAT64: - return {"number_value": value} - elif type_code == data_types_pb2.TIMESTAMP: - return {"string_value": value} - elif type_code == data_types_pb2.STRING: - return {"string_value": value} + return struct_pb2.Value(number_value=value) + elif ( + type_code == data_types_pb2.TIMESTAMP + or type_code == data_types_pb2.STRING + or type_code == data_types_pb2.CATEGORY + ): + return struct_pb2.Value(string_value=value) elif type_code == data_types_pb2.ARRAY: - return {"list_value": value} + if isinstance(value, struct_pb2.ListValue): + # in case the user passed in a ListValue. + return struct_pb2.Value(list_value=value) + array = [] + for item in value: + array.append( + self.__data_type_to_proto_value(data_type.list_element_type, item) + ) + return struct_pb2.Value(list_value=struct_pb2.ListValue(values=array)) elif type_code == data_types_pb2.STRUCT: - return {"struct_value": value} - elif type_code == data_types_pb2.CATEGORY: - return {"string_value": value} + if isinstance(value, struct_pb2.Struct): + # in case the user passed in a Struct. + return struct_pb2.Value(struct_value=value) + struct_value = struct_pb2.Struct() + for k, v in value.items(): + field_value = self.__data_type_to_proto_value( + data_type.struct_type.fields[k], v + ) + struct_value.fields[k].CopyFrom(field_value) + return struct_pb2.Value(struct_value=struct_value) else: raise ValueError("Unknown type_code: {}".format(type_code)) @@ -2694,16 +2714,17 @@ def predict( values = [] for i, c in zip(inputs, column_specs): - value_type = self.__type_code_to_value_type(c.data_type.type_code, i) + value_type = self.__data_type_to_proto_value(c.data_type, i) values.append(value_type) - request = {"row": {"values": values}} + row = data_items_pb2.Row(values=values) + payload = data_items_pb2.ExamplePayload(row=row) params = None if feature_importance: params = {"feature_importance": "true"} - return self.prediction_client.predict(model.name, request, params, **kwargs) + return self.prediction_client.predict(model.name, payload, params, **kwargs) def batch_predict( self, diff --git a/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py b/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py index ce513083..3566846d 100644 --- a/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py +++ b/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py @@ -23,7 +23,8 @@ from google.api_core import exceptions from google.auth.credentials import AnonymousCredentials from google.cloud import automl_v1beta1 -from google.cloud.automl_v1beta1.proto import data_types_pb2 +from google.cloud.automl_v1beta1.proto import data_types_pb2, data_items_pb2 +from google.protobuf import struct_pb2 PROJECT = "project" REGION = "region" @@ -1116,9 +1117,10 @@ def test_predict_from_array(self): model.configure_mock(tables_model_metadata=model_metadata, name="my_model") client = self.tables_client({"get_model.return_value": model}, {}) client.predict(["1"], model_name="my_model") - client.prediction_client.predict.assert_called_with( - "my_model", {"row": {"values": [{"string_value": "1"}]}}, None + payload = data_items_pb2.ExamplePayload( + row=data_items_pb2.Row(values=[struct_pb2.Value(string_value="1")]) ) + client.prediction_client.predict.assert_called_with("my_model", payload, None) def test_predict_from_dict(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) @@ -1131,11 +1133,15 @@ def test_predict_from_dict(self): model.configure_mock(tables_model_metadata=model_metadata, name="my_model") client = self.tables_client({"get_model.return_value": model}, {}) client.predict({"a": "1", "b": "2"}, model_name="my_model") - client.prediction_client.predict.assert_called_with( - "my_model", - {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}}, - None, + payload = data_items_pb2.ExamplePayload( + row=data_items_pb2.Row( + values=[ + struct_pb2.Value(string_value="1"), + struct_pb2.Value(string_value="2"), + ] + ) ) + client.prediction_client.predict.assert_called_with("my_model", payload, None) def test_predict_from_dict_with_feature_importance(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY) @@ -1150,10 +1156,16 @@ def test_predict_from_dict_with_feature_importance(self): client.predict( {"a": "1", "b": "2"}, model_name="my_model", feature_importance=True ) + payload = data_items_pb2.ExamplePayload( + row=data_items_pb2.Row( + values=[ + struct_pb2.Value(string_value="1"), + struct_pb2.Value(string_value="2"), + ] + ) + ) client.prediction_client.predict.assert_called_with( - "my_model", - {"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}}, - {"feature_importance": "true"}, + "my_model", payload, {"feature_importance": "true"} ) def test_predict_from_dict_missing(self): @@ -1167,18 +1179,32 @@ def test_predict_from_dict_missing(self): model.configure_mock(tables_model_metadata=model_metadata, name="my_model") client = self.tables_client({"get_model.return_value": model}, {}) client.predict({"a": "1"}, model_name="my_model") - client.prediction_client.predict.assert_called_with( - "my_model", - {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}}, - None, + payload = data_items_pb2.ExamplePayload( + row=data_items_pb2.Row( + values=[ + struct_pb2.Value(string_value="1"), + struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), + ] + ) ) + client.prediction_client.predict.assert_called_with("my_model", payload, None) def test_predict_all_types(self): float_type = mock.Mock(type_code=data_types_pb2.FLOAT64) timestamp_type = mock.Mock(type_code=data_types_pb2.TIMESTAMP) string_type = mock.Mock(type_code=data_types_pb2.STRING) - array_type = mock.Mock(type_code=data_types_pb2.ARRAY) - struct_type = mock.Mock(type_code=data_types_pb2.STRUCT) + array_type = mock.Mock( + type_code=data_types_pb2.ARRAY, + list_element_type=mock.Mock(type_code=data_types_pb2.FLOAT64), + ) + struct = data_types_pb2.StructType() + struct.fields["a"].CopyFrom( + data_types_pb2.DataType(type_code=data_types_pb2.CATEGORY) + ) + struct.fields["b"].CopyFrom( + data_types_pb2.DataType(type_code=data_types_pb2.CATEGORY) + ) + struct_type = mock.Mock(type_code=data_types_pb2.STRUCT, struct_type=struct) category_type = mock.Mock(type_code=data_types_pb2.CATEGORY) column_spec_float = mock.Mock(display_name="float", data_type=float_type) column_spec_timestamp = mock.Mock( @@ -1211,29 +1237,33 @@ def test_predict_all_types(self): "timestamp": "EST", "string": "text", "array": [1], - "struct": {"a": "b"}, + "struct": {"a": "label_a", "b": "label_b"}, "category": "a", "null": None, }, model_name="my_model", ) - client.prediction_client.predict.assert_called_with( - "my_model", - { - "row": { - "values": [ - {"number_value": 1.0}, - {"string_value": "EST"}, - {"string_value": "text"}, - {"list_value": [1]}, - {"struct_value": {"a": "b"}}, - {"string_value": "a"}, - {"null_value": 0}, - ] - } - }, - None, + struct = struct_pb2.Struct() + struct.fields["a"].CopyFrom(struct_pb2.Value(string_value="label_a")) + struct.fields["b"].CopyFrom(struct_pb2.Value(string_value="label_b")) + payload = data_items_pb2.ExamplePayload( + row=data_items_pb2.Row( + values=[ + struct_pb2.Value(number_value=1.0), + struct_pb2.Value(string_value="EST"), + struct_pb2.Value(string_value="text"), + struct_pb2.Value( + list_value=struct_pb2.ListValue( + values=[struct_pb2.Value(number_value=1.0)] + ) + ), + struct_pb2.Value(struct_value=struct), + struct_pb2.Value(string_value="a"), + struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE), + ] + ) ) + client.prediction_client.predict.assert_called_with("my_model", payload, None) def test_predict_from_array_missing(self): data_type = mock.Mock(type_code=data_types_pb2.CATEGORY)