Skip to content

Commit

Permalink
feat: expose location field of model (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
Praful Makani committed Feb 18, 2020
1 parent 5212b2f commit 646c2b4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
Expand Up @@ -108,6 +108,12 @@ public Builder setLabels(Map<String, String> labels) {
return this;
}

@Override
Builder setLocation(String location) {
infoBuilder.setLocation(location);
return this;
}

@Override
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
infoBuilder.setTrainingRuns(trainingRunList);
Expand Down
Expand Up @@ -68,6 +68,7 @@ public Model apply(ModelInfo ModelInfo) {
private final Long lastModifiedTime;
private final Long expirationTime;
private final Labels labels;
private final String location;
private final ImmutableList<TrainingRun> trainingRunList;
private final ImmutableList<StandardSQLField> featureColumnList;
private final ImmutableList<StandardSQLField> labelColumnList;
Expand Down Expand Up @@ -97,6 +98,8 @@ public abstract static class Builder {
*/
public abstract Builder setLabels(Map<String, String> labels);

abstract Builder setLocation(String location);

public abstract Builder setModelId(ModelId modelId);

abstract Builder setEtag(String etag);
Expand Down Expand Up @@ -130,6 +133,7 @@ static class BuilderImpl extends Builder {
private Long lastModifiedTime;
private Long expirationTime;
private Labels labels = Labels.ZERO;
private String location;
private List<TrainingRun> trainingRunList = Collections.emptyList();
private List<StandardSQLField> labelColumnList = Collections.emptyList();
private List<StandardSQLField> featureColumnList = Collections.emptyList();
Expand All @@ -150,6 +154,7 @@ static class BuilderImpl extends Builder {
this.labelColumnList = modelInfo.labelColumnList;
this.featureColumnList = modelInfo.featureColumnList;
this.encryptionConfiguration = modelInfo.encryptionConfiguration;
this.location = modelInfo.location;
}

BuilderImpl(Model modelPb) {
Expand All @@ -165,6 +170,7 @@ static class BuilderImpl extends Builder {
this.lastModifiedTime = modelPb.getLastModifiedTime();
this.expirationTime = modelPb.getExpirationTime();
this.labels = Labels.fromPb(modelPb.getLabels());
this.location = modelPb.getLocation();
if (modelPb.getTrainingRuns() != null) {
this.trainingRunList = modelPb.getTrainingRuns();
}
Expand Down Expand Up @@ -236,6 +242,12 @@ public Builder setLabels(Map<String, String> labels) {
return this;
}

@Override
Builder setLocation(String location) {
this.location = location;
return this;
}

@Override
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
this.trainingRunList = checkNotNull(trainingRunList);
Expand Down Expand Up @@ -276,6 +288,7 @@ public ModelInfo build() {
this.lastModifiedTime = builder.lastModifiedTime;
this.expirationTime = builder.expirationTime;
this.labels = builder.labels;
this.location = builder.location;
this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList);
this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList);
this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList);
Expand Down Expand Up @@ -330,6 +343,11 @@ public Map<String, String> getLabels() {
return labels.userMap();
}

/** Returns a location of the model. */
public String getLocation() {
return location;
}

/** Returns metadata about each training run iteration. */
@BetaApi
public ImmutableList<TrainingRun> getTrainingRuns() {
Expand Down Expand Up @@ -368,6 +386,7 @@ public String toString() {
.add("lastModifiedTime", lastModifiedTime)
.add("expirationTime", expirationTime)
.add("labels", labels)
.add("location", location)
.add("trainingRuns", trainingRunList)
.add("labelColumns", labelColumnList)
.add("featureColumns", featureColumnList)
Expand Down Expand Up @@ -416,6 +435,7 @@ Model toPb() {
modelPb.setLastModifiedTime(lastModifiedTime);
modelPb.setExpirationTime(expirationTime);
modelPb.setLabels(labels.toPb());
modelPb.setLocation(location);
modelPb.setTrainingRuns(trainingRunList);
if (labelColumnList != null) {
modelPb.setLabelColumns(Lists.transform(labelColumnList, StandardSQLField.TO_PB_FUNCTION));
Expand Down
Expand Up @@ -33,6 +33,7 @@ public class ModelInfoTest {
private static final Long EXPIRATION_TIME = 30L;
private static final String DESCRIPTION = "description";
private static final String FRIENDLY_NAME = "friendlyname";
private static final String LOCATION = "US";
private static final EncryptionConfiguration MODEL_ENCRYPTION_CONFIGURATION =
EncryptionConfiguration.newBuilder().setKmsKeyName("KMS_KEY_1").build();

Expand All @@ -52,6 +53,7 @@ public class ModelInfoTest {
.setFriendlyName(FRIENDLY_NAME)
.setTrainingRuns(TRAINING_RUN_LIST)
.setEncryptionConfiguration(MODEL_ENCRYPTION_CONFIGURATION)
.setLocation(LOCATION)
.build();

@Test
Expand All @@ -75,6 +77,7 @@ public void testBuilder() {
assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName());
assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions());
assertEquals(MODEL_ENCRYPTION_CONFIGURATION, MODEL_INFO.getEncryptionConfiguration());
assertEquals(LOCATION, MODEL_INFO.getLocation());
}

@Test
Expand All @@ -88,6 +91,7 @@ public void testOf() {
assertNull(modelInfo.getDescription());
assertNull(modelInfo.getFriendlyName());
assertNull(modelInfo.getEncryptionConfiguration());
assertNull(modelInfo.getLocation());
assertEquals(modelInfo.getTrainingRuns().isEmpty(), true);
assertEquals(modelInfo.getLabelColumns().isEmpty(), true);
assertEquals(modelInfo.getFeatureColumns().isEmpty(), true);
Expand All @@ -113,6 +117,7 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) {
assertEquals(expected.getDescription(), value.getDescription());
assertEquals(expected.getFriendlyName(), value.getFriendlyName());
assertEquals(expected.getLabels(), value.getLabels());
assertEquals(expected.getLocation(), value.getLocation());
assertEquals(expected.hashCode(), value.hashCode());
assertEquals(expected.getTrainingRuns(), value.getTrainingRuns());
assertEquals(expected.getLabelColumns(), value.getLabelColumns());
Expand Down

0 comments on commit 646c2b4

Please sign in to comment.