Skip to content
This repository has been archived by the owner on Sep 9, 2023. It is now read-only.

Commit

Permalink
feat: adds ValueConverter utility and demo samples (#108)
Browse files Browse the repository at this point in the history
* feat: adds value converter utility class and demo samples

* feat: samples updated for EJCL

* fix: removed local file references

* feat: adds ValueConverter tests

Co-authored-by: yoshi-code-bot <70984784+yoshi-code-bot@users.noreply.github.com>
  • Loading branch information
telpirion and yoshi-code-bot committed Dec 18, 2020
1 parent e8d357a commit cf0b763
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 30 deletions.
9 changes: 9 additions & 0 deletions google-cloud-aiplatform/pom.xml
Expand Up @@ -83,6 +83,15 @@
<classifier>testlib</classifier>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java-util</artifactId>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<profiles>
Expand Down
@@ -0,0 +1,62 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.aiplatform.util;

import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;

/**
* Exposes utility methods for converting AI Platform messages to and from
* {@com.google.protobuf.Value} objects.
*/
public class ValueConverter {

/** An empty {@com.google.protobuf.Value} message. */
public static final Value EMPTY_VALUE = Value.newBuilder().build();

/**
* Converts a message type to a {@com.google.protobuf.Value}.
*
* @param message the message to convert
* @return the message as a {@com.google.protobuf.Value}
* @throws InvalidProtocolBufferException
*/
public static Value toValue(Message message) throws InvalidProtocolBufferException {
String jsonString = JsonFormat.printer().print(message);
Value.Builder value = Value.newBuilder();
JsonFormat.parser().merge(jsonString, value);
return value.build();
}

/**
* Converts a {@com.google.protobuf.Value} to a {@com.google.protobuf.Message} of the provided
* {@com.google.protobuf.Message.Builder}.
*
* @param messageBuilder a builder for the message type
* @param value the Value to convert to a message
* @return the value as a message
* @throws InvalidProtocolBufferException
*/
public static Message fromValue(Message.Builder messageBuilder, Value value)
throws InvalidProtocolBufferException {
String valueString = JsonFormat.printer().print(value);
JsonFormat.parser().merge(valueString, messageBuilder);
return messageBuilder.build();
}
}
@@ -0,0 +1,123 @@
/*
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.aiplatform.util;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;

import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.gson.JsonObject;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MapEntry;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.util.Collection;
import org.junit.Test;
import org.junit.function.ThrowingRunnable;

public class ValueConverterTest {

@Test
public void testValueConverterToValue() throws InvalidProtocolBufferException {
AutoMlImageClassificationInputs testObjectInputs =
AutoMlImageClassificationInputs.newBuilder()
.setModelType(ModelType.CLOUD)
.setBudgetMilliNodeHours(8000)
.setMultiLabel(true)
.setDisableEarlyStopping(false)
.build();

Value actualConvertedValue = ValueConverter.toValue(testObjectInputs);

Struct actualStruct = actualConvertedValue.getStructValue();
assertEquals(3, actualStruct.getFieldsCount());

Collection<Object> innerFields = actualStruct.getAllFields().values();
Collection<MapEntry> fieldEntries = (Collection<MapEntry>) innerFields.toArray()[0];

MapEntry actualBoolValueEntry = null;
MapEntry actualStringValueEntry = null;
MapEntry actualNumberValueEntry = null;

for (MapEntry entry : fieldEntries) {
String key = entry.getKey().toString();
if (key.equals("multiLabel")) {
actualBoolValueEntry = entry;
} else if (key.equals("modelType")) {
actualStringValueEntry = entry;
} else if (key.equals("budgetMilliNodeHours")) {
actualNumberValueEntry = entry;
}
}

Value actualBoolValue = (Value) actualBoolValueEntry.getValue();
assertEquals(testObjectInputs.getMultiLabel(), actualBoolValue.getBoolValue());

Value actualStringValue = (Value) actualStringValueEntry.getValue();
assertEquals("CLOUD", actualStringValue.getStringValue());

Value actualNumberValue = (Value) actualNumberValueEntry.getValue();
// protobuf stores int64 values as strings rather than numbers
long actualNumber = Long.parseLong(actualNumberValue.getStringValue());
assertEquals(testObjectInputs.getBudgetMilliNodeHours(), actualNumber);
}

@Test
public void testValueConverterFromValue() throws InvalidProtocolBufferException {

JsonObject testJsonInputs = new JsonObject();
testJsonInputs.addProperty("multi_label", true);
testJsonInputs.addProperty("model_type", "CLOUD");
testJsonInputs.addProperty("budget_milli_node_hours", 8000);

Value.Builder valueBuilder = Value.newBuilder();
JsonFormat.parser().merge(testJsonInputs.toString(), valueBuilder);
Value testValueInputs = valueBuilder.build();

AutoMlImageClassificationInputs actualInputs =
(AutoMlImageClassificationInputs)
ValueConverter.fromValue(AutoMlImageClassificationInputs.newBuilder(), testValueInputs);

assertEquals(8000, actualInputs.getBudgetMilliNodeHours());
assertEquals(true, actualInputs.getMultiLabel());
assertEquals(ModelType.CLOUD, actualInputs.getModelType());
}

@Test
public void testValueConverterFromValueWithBadInputs() throws InvalidProtocolBufferException {
JsonObject testBadJsonInputs = new JsonObject();
testBadJsonInputs.addProperty("wrong_key", "some_value");

Value.Builder badValueBuilder = Value.newBuilder();
JsonFormat.parser().merge(testBadJsonInputs.toString(), badValueBuilder);
final Value testBadValueInputs = badValueBuilder.build();

assertThrows(
InvalidProtocolBufferException.class,
new ThrowingRunnable() {
@Override
public void run() throws Throwable {
AutoMlImageClassificationInputs actualBadInput =
(AutoMlImageClassificationInputs)
ValueConverter.fromValue(
AutoMlImageClassificationInputs.newBuilder(), testBadValueInputs);
}
});
}
}
2 changes: 1 addition & 1 deletion samples/snippets/pom.xml
Expand Up @@ -27,7 +27,7 @@
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-aiplatform</artifactId>
<version>0.1.0</version>
<version>0.1.1-SNAPSHOT</version>
</dependency>
<!-- [END aiplatform_install_with_bom] -->
<dependency>
Expand Down
Expand Up @@ -17,7 +17,7 @@
package aiplatform;

// [START aiplatform_create_training_pipeline_image_classification_sample]

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
import com.google.cloud.aiplatform.v1beta1.EnvVar;
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
Expand All @@ -38,8 +38,8 @@
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

Expand Down Expand Up @@ -74,11 +74,13 @@ static void createTrainingPipelineImageClassificationSample(
+ "automl_image_classification_1.0.0.yaml";
LocationName locationName = LocationName.of(project, location);

String jsonString =
"{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
+ " \"disableEarlyStopping\": false}";
Value.Builder trainingTaskInputs = Value.newBuilder();
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
AutoMlImageClassificationInputs autoMlImageClassificationInputs =
AutoMlImageClassificationInputs.newBuilder()
.setModelType(ModelType.CLOUD)
.setMultiLabel(false)
.setBudgetMilliNodeHours(8000)
.setDisableEarlyStopping(false)
.build();

InputDataConfig trainingInputDataConfig =
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
Expand All @@ -87,7 +89,7 @@ static void createTrainingPipelineImageClassificationSample(
TrainingPipeline.newBuilder()
.setDisplayName(trainingPipelineDisplayName)
.setTrainingTaskDefinition(trainingTaskDefinition)
.setTrainingTaskInputs(trainingTaskInputs)
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
.setInputDataConfig(trainingInputDataConfig)
.setModelToUpload(model)
.build();
Expand Down
Expand Up @@ -19,12 +19,15 @@
// [START aiplatform_predict_image_classification_sample]

import com.google.api.client.util.Base64;
import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1beta1.EndpointName;
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.ImageClassificationPredictionInstance;
import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageClassificationPredictionParams;
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
Expand Down Expand Up @@ -60,23 +63,42 @@ static void predictImageClassification(String project, String fileName, String e
byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
String content = new String(contents, StandardCharsets.UTF_8);

Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();

String contentDict = "{\"content\": \"" + content + "\"}";
Value.Builder instance = Value.newBuilder();
JsonFormat.parser().merge(contentDict, instance);
ImageClassificationPredictionInstance predictionInstance =
ImageClassificationPredictionInstance.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(instance.build());
instances.add(ValueConverter.toValue(predictionInstance));

ImageClassificationPredictionParams predictionParams =
ImageClassificationPredictionParams.newBuilder()
.setConfidenceThreshold((float) 0.5)
.setMaxPredictions(5)
.build();

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameter);
predictionServiceClient.predict(endpointName, instances,
ValueConverter.toValue(predictionParams));
System.out.println("Predict Image Classification Response");
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

System.out.println("Predictions");
for (Value prediction : predictResponse.getPredictionsList()) {
System.out.format("\tPrediction: %s\n", prediction);

ClassificationPredictionResult.Builder resultBuilder =
ClassificationPredictionResult.newBuilder();
// Display names and confidences values correspond to
// IDs in the ID list.
ClassificationPredictionResult result =
(ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
int counter = 0;
for (Long id : result.getIdsList()) {
System.out.printf("Label ID: %d\n", id);
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
counter++;
}
}
}
}
Expand Down
Expand Up @@ -17,13 +17,14 @@
package aiplatform;

// [START aiplatform_predict_text_classification_sample]

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1beta1.EndpointName;
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.TextClassificationPredictionInstance;
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -52,25 +53,38 @@ static void predictTextClassificationSingleLabel(
try (PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings)) {
String location = "us-central1";
String jsonString = "{\"content\": \"" + content + "\"}";

EndpointName endpointName = EndpointName.of(project, location, endpointId);

Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
Value.Builder instance = Value.newBuilder();
JsonFormat.parser().merge(jsonString, instance);
TextClassificationPredictionInstance predictionInstance = TextClassificationPredictionInstance
.newBuilder()
.setContent(content)
.build();

List<Value> instances = new ArrayList<>();
instances.add(instance.build());
instances.add(ValueConverter.toValue(predictionInstance));

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameter);
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
System.out.println("Predict Text Classification Response");
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());

System.out.println("Predictions");
System.out.println("Predictions:\n\n");
for (Value prediction : predictResponse.getPredictionsList()) {
System.out.format("\tPrediction: %s\n", prediction);

ClassificationPredictionResult.Builder resultBuilder =
ClassificationPredictionResult.newBuilder();

// Display names and confidences values correspond to
// IDs in the ID list.
ClassificationPredictionResult result =
(ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
int counter = 0;
for (Long id : result.getIdsList()) {
System.out.printf("Label ID: %d\n", id);
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
counter++;
}
}
}
}
Expand Down

0 comments on commit cf0b763

Please sign in to comment.