Skip to content

Commit

Permalink
fix: Support multiple instances in custom predict sample (#857)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinnysenthil committed Dec 3, 2021
1 parent dd1f650 commit 8cb4839
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

# [START aiplatform_predict_custom_trained_model_sample]
from typing import Dict
from typing import Dict, List, Union

from google.cloud import aiplatform
from google.protobuf import json_format
Expand All @@ -23,18 +23,24 @@
def predict_custom_trained_model_sample(
project: str,
endpoint_id: str,
instance_dict: Dict,
instances: Union[Dict, List[Dict]],
location: str = "us-central1",
api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
"""
`instances` can be either single instance of type dict or a list
of instances.
"""
# The AI Platform services require regional API endpoints.
client_options = {"api_endpoint": api_endpoint}
# Initialize client that will be used to create and send requests.
# This client only needs to be created once, and can be reused for multiple requests.
client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
# The format of each instance should conform to the deployed model's prediction input schema.
instance = json_format.ParseDict(instance_dict, Value())
instances = [instance]
instances = instances if type(instances) == list else [instances]
instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
parameters_dict = {}
parameters = json_format.ParseDict(parameters_dict, Value())
endpoint = client.endpoint_path(
Expand Down
Expand Up @@ -32,9 +32,20 @@ def test_ucaip_generated_predict_custom_trained_model_sample(capsys):

instance_dict = {"image_bytes": {"b64": encoded_content}, "key": "0"}

# Single instance as a dict
predict_custom_trained_model_sample.predict_custom_trained_model_sample(
instance_dict=instance_dict, project=PROJECT_ID, endpoint_id=ENDPOINT_ID
instances=instance_dict, project=PROJECT_ID, endpoint_id=ENDPOINT_ID
)

# Multiple instances in a list
predict_custom_trained_model_sample.predict_custom_trained_model_sample(
instances=[instance_dict, instance_dict],
project=PROJECT_ID,
endpoint_id=ENDPOINT_ID,
)

out, _ = capsys.readouterr()
assert "1.0" in out

# Two sets of scores for multi-instance, one score for single instance
assert out.count("scores") == 3

0 comments on commit 8cb4839

Please sign in to comment.