From 7983b448158cf8166ada54c60fb896d5658a2162 Mon Sep 17 00:00:00 2001 From: Yu-Han Liu Date: Wed, 23 Dec 2020 09:20:02 -0800 Subject: [PATCH] fix: predict image samples params (#150) --- .../predict_image_classification_sample.py | 4 +-- .../predict_image_object_detection_sample.py | 4 +-- .../predict_image_classification_sample.py | 27 +++++++++---------- .../predict_image_object_detection_sample.py | 2 +- 4 files changed, 17 insertions(+), 20 deletions(-) diff --git a/.sample_configs/param_handlers/predict_image_classification_sample.py b/.sample_configs/param_handlers/predict_image_classification_sample.py index ca8f00dc13..abecfd94dc 100644 --- a/.sample_configs/param_handlers/predict_image_classification_sample.py +++ b/.sample_configs/param_handlers/predict_image_classification_sample.py @@ -32,8 +32,8 @@ def make_instances(filename: str) -> typing.Sequence[google.protobuf.struct_pb2. def make_parameters() -> google.protobuf.struct_pb2.Value: # See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters. parameters_dict = { - "confidence_threshold": 0.5, - "max_predictions": 5 + "confidenceThreshold": 0.5, + "maxPredictions": 5 } parameters = to_protobuf_value(parameters_dict) diff --git a/.sample_configs/param_handlers/predict_image_object_detection_sample.py b/.sample_configs/param_handlers/predict_image_object_detection_sample.py index 975558e1ab..cd897bfa5d 100644 --- a/.sample_configs/param_handlers/predict_image_object_detection_sample.py +++ b/.sample_configs/param_handlers/predict_image_object_detection_sample.py @@ -32,8 +32,8 @@ def make_instances(filename: str) -> typing.Sequence[google.protobuf.struct_pb2. def make_parameters() -> google.protobuf.struct_pb2.Value: # See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml for the format of the parameters. parameters_dict = { - "confidence_threshold": 0.5, - "max_predictions": 5 + "confidenceThreshold": 0.5, + "maxPredictions": 5 } parameters = to_protobuf_value(parameters_dict) diff --git a/samples/snippets/predict_image_classification_sample.py b/samples/snippets/predict_image_classification_sample.py index 126c21664b..958de5e156 100644 --- a/samples/snippets/predict_image_classification_sample.py +++ b/samples/snippets/predict_image_classification_sample.py @@ -16,7 +16,8 @@ import base64 from google.cloud import aiplatform -from google.cloud.aiplatform.schema import predict +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Value def predict_image_classification_sample( @@ -36,29 +37,25 @@ def predict_image_classification_sample( # The format of each instance should conform to the deployed model's prediction input schema. encoded_content = base64.b64encode(file_content).decode("utf-8") + instance_dict = {"content": encoded_content} - instance_obj = predict.instance.ImageClassificationPredictionInstance( - content=encoded_content) - - instance_val = instance_obj.to_value() - instances = [instance_val] - - params_obj = predict.params.ImageClassificationPredictionParams( - confidence_threshold=0.5, max_predictions=5) - + instance = json_format.ParseDict(instance_dict, Value()) + instances = [instance] + # See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters. + parameters_dict = {"confidenceThreshold": 0.5, "maxPredictions": 5} + parameters = json_format.ParseDict(parameters_dict, Value()) endpoint = client.endpoint_path( project=project, location=location, endpoint=endpoint_id ) response = client.predict( - endpoint=endpoint, instances=instances, parameters=params_obj + endpoint=endpoint, instances=instances, parameters=parameters ) print("response") - print("\tdeployed_model_id:", response.deployed_model_id) + print(" deployed_model_id:", response.deployed_model_id) # See gs://google-cloud-aiplatform/schema/predict/prediction/classification.yaml for the format of the predictions. predictions = response.predictions - for prediction_ in predictions: - prediction_obj = predict.prediction.ClassificationPredictionResult.from_map(prediction_) - print(prediction_obj) + for prediction in predictions: + print(" prediction:", dict(prediction)) # [END aiplatform_predict_image_classification_sample] diff --git a/samples/snippets/predict_image_object_detection_sample.py b/samples/snippets/predict_image_object_detection_sample.py index 82561581b3..7b1f9afd1a 100644 --- a/samples/snippets/predict_image_object_detection_sample.py +++ b/samples/snippets/predict_image_object_detection_sample.py @@ -42,7 +42,7 @@ def predict_image_object_detection_sample( instance = json_format.ParseDict(instance_dict, Value()) instances = [instance] # See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml for the format of the parameters. - parameters_dict = {"confidence_threshold": 0.5, "max_predictions": 5} + parameters_dict = {"confidenceThreshold": 0.5, "maxPredictions": 5} parameters = json_format.ParseDict(parameters_dict, Value()) endpoint = client.endpoint_path( project=project, location=location, endpoint=endpoint_id