Skip to content

Commit

Permalink
fix: predict image samples params (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcology committed Dec 23, 2020
1 parent 69fc7fd commit 7983b44
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
Expand Up @@ -32,8 +32,8 @@ def make_instances(filename: str) -> typing.Sequence[google.protobuf.struct_pb2.
def make_parameters() -> google.protobuf.struct_pb2.Value:
# See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters.
parameters_dict = {
"confidence_threshold": 0.5,
"max_predictions": 5
"confidenceThreshold": 0.5,
"maxPredictions": 5
}
parameters = to_protobuf_value(parameters_dict)

Expand Down
Expand Up @@ -32,8 +32,8 @@ def make_instances(filename: str) -> typing.Sequence[google.protobuf.struct_pb2.
def make_parameters() -> google.protobuf.struct_pb2.Value:
# See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml for the format of the parameters.
parameters_dict = {
"confidence_threshold": 0.5,
"max_predictions": 5
"confidenceThreshold": 0.5,
"maxPredictions": 5
}
parameters = to_protobuf_value(parameters_dict)

Expand Down
27 changes: 12 additions & 15 deletions samples/snippets/predict_image_classification_sample.py
Expand Up @@ -16,7 +16,8 @@
import base64

from google.cloud import aiplatform
from google.cloud.aiplatform.schema import predict
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


def predict_image_classification_sample(
Expand All @@ -36,29 +37,25 @@ def predict_image_classification_sample(

# The format of each instance should conform to the deployed model's prediction input schema.
encoded_content = base64.b64encode(file_content).decode("utf-8")
instance_dict = {"content": encoded_content}

instance_obj = predict.instance.ImageClassificationPredictionInstance(
content=encoded_content)

instance_val = instance_obj.to_value()
instances = [instance_val]

params_obj = predict.params.ImageClassificationPredictionParams(
confidence_threshold=0.5, max_predictions=5)

instance = json_format.ParseDict(instance_dict, Value())
instances = [instance]
# See gs://google-cloud-aiplatform/schema/predict/params/image_classification_1.0.0.yaml for the format of the parameters.
parameters_dict = {"confidenceThreshold": 0.5, "maxPredictions": 5}
parameters = json_format.ParseDict(parameters_dict, Value())
endpoint = client.endpoint_path(
project=project, location=location, endpoint=endpoint_id
)
response = client.predict(
endpoint=endpoint, instances=instances, parameters=params_obj
endpoint=endpoint, instances=instances, parameters=parameters
)
print("response")
print("\tdeployed_model_id:", response.deployed_model_id)
print(" deployed_model_id:", response.deployed_model_id)
# See gs://google-cloud-aiplatform/schema/predict/prediction/classification.yaml for the format of the predictions.
predictions = response.predictions
for prediction_ in predictions:
prediction_obj = predict.prediction.ClassificationPredictionResult.from_map(prediction_)
print(prediction_obj)
for prediction in predictions:
print(" prediction:", dict(prediction))


# [END aiplatform_predict_image_classification_sample]
2 changes: 1 addition & 1 deletion samples/snippets/predict_image_object_detection_sample.py
Expand Up @@ -42,7 +42,7 @@ def predict_image_object_detection_sample(
instance = json_format.ParseDict(instance_dict, Value())
instances = [instance]
# See gs://google-cloud-aiplatform/schema/predict/params/image_object_detection_1.0.0.yaml for the format of the parameters.
parameters_dict = {"confidence_threshold": 0.5, "max_predictions": 5}
parameters_dict = {"confidenceThreshold": 0.5, "maxPredictions": 5}
parameters = json_format.ParseDict(parameters_dict, Value())
endpoint = client.endpoint_path(
project=project, location=location, endpoint=endpoint_id
Expand Down

0 comments on commit 7983b44

Please sign in to comment.