Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: expose location field of model #175

Merged
merged 1 commit into from Feb 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -108,6 +108,12 @@ public Builder setLabels(Map<String, String> labels) {
return this;
}

@Override
Builder setLocation(String location) {
infoBuilder.setLocation(location);
return this;
}

@Override
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
infoBuilder.setTrainingRuns(trainingRunList);
Expand Down
Expand Up @@ -68,6 +68,7 @@ public Model apply(ModelInfo ModelInfo) {
private final Long lastModifiedTime;
private final Long expirationTime;
private final Labels labels;
private final String location;
private final ImmutableList<TrainingRun> trainingRunList;
private final ImmutableList<StandardSQLField> featureColumnList;
private final ImmutableList<StandardSQLField> labelColumnList;
Expand Down Expand Up @@ -97,6 +98,8 @@ public abstract static class Builder {
*/
public abstract Builder setLabels(Map<String, String> labels);

abstract Builder setLocation(String location);

public abstract Builder setModelId(ModelId modelId);

abstract Builder setEtag(String etag);
Expand Down Expand Up @@ -130,6 +133,7 @@ static class BuilderImpl extends Builder {
private Long lastModifiedTime;
private Long expirationTime;
private Labels labels = Labels.ZERO;
private String location;
private List<TrainingRun> trainingRunList = Collections.emptyList();
private List<StandardSQLField> labelColumnList = Collections.emptyList();
private List<StandardSQLField> featureColumnList = Collections.emptyList();
Expand All @@ -150,6 +154,7 @@ static class BuilderImpl extends Builder {
this.labelColumnList = modelInfo.labelColumnList;
this.featureColumnList = modelInfo.featureColumnList;
this.encryptionConfiguration = modelInfo.encryptionConfiguration;
this.location = modelInfo.location;
}

BuilderImpl(Model modelPb) {
Expand All @@ -165,6 +170,7 @@ static class BuilderImpl extends Builder {
this.lastModifiedTime = modelPb.getLastModifiedTime();
this.expirationTime = modelPb.getExpirationTime();
this.labels = Labels.fromPb(modelPb.getLabels());
this.location = modelPb.getLocation();
if (modelPb.getTrainingRuns() != null) {
this.trainingRunList = modelPb.getTrainingRuns();
}
Expand Down Expand Up @@ -236,6 +242,12 @@ public Builder setLabels(Map<String, String> labels) {
return this;
}

@Override
Builder setLocation(String location) {
this.location = location;
return this;
}

@Override
Builder setTrainingRuns(List<TrainingRun> trainingRunList) {
this.trainingRunList = checkNotNull(trainingRunList);
Expand Down Expand Up @@ -276,6 +288,7 @@ public ModelInfo build() {
this.lastModifiedTime = builder.lastModifiedTime;
this.expirationTime = builder.expirationTime;
this.labels = builder.labels;
this.location = builder.location;
this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList);
this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList);
this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList);
Expand Down Expand Up @@ -330,6 +343,11 @@ public Map<String, String> getLabels() {
return labels.userMap();
}

/** Returns a location of the model. */
public String getLocation() {
return location;
}

/** Returns metadata about each training run iteration. */
@BetaApi
public ImmutableList<TrainingRun> getTrainingRuns() {
Expand Down Expand Up @@ -368,6 +386,7 @@ public String toString() {
.add("lastModifiedTime", lastModifiedTime)
.add("expirationTime", expirationTime)
.add("labels", labels)
.add("location", location)
.add("trainingRuns", trainingRunList)
.add("labelColumns", labelColumnList)
.add("featureColumns", featureColumnList)
Expand Down Expand Up @@ -416,6 +435,7 @@ Model toPb() {
modelPb.setLastModifiedTime(lastModifiedTime);
modelPb.setExpirationTime(expirationTime);
modelPb.setLabels(labels.toPb());
modelPb.setLocation(location);
modelPb.setTrainingRuns(trainingRunList);
if (labelColumnList != null) {
modelPb.setLabelColumns(Lists.transform(labelColumnList, StandardSQLField.TO_PB_FUNCTION));
Expand Down
Expand Up @@ -33,6 +33,7 @@ public class ModelInfoTest {
private static final Long EXPIRATION_TIME = 30L;
private static final String DESCRIPTION = "description";
private static final String FRIENDLY_NAME = "friendlyname";
private static final String LOCATION = "US";
private static final EncryptionConfiguration MODEL_ENCRYPTION_CONFIGURATION =
EncryptionConfiguration.newBuilder().setKmsKeyName("KMS_KEY_1").build();

Expand All @@ -52,6 +53,7 @@ public class ModelInfoTest {
.setFriendlyName(FRIENDLY_NAME)
.setTrainingRuns(TRAINING_RUN_LIST)
.setEncryptionConfiguration(MODEL_ENCRYPTION_CONFIGURATION)
.setLocation(LOCATION)
.build();

@Test
Expand All @@ -75,6 +77,7 @@ public void testBuilder() {
assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName());
assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions());
assertEquals(MODEL_ENCRYPTION_CONFIGURATION, MODEL_INFO.getEncryptionConfiguration());
assertEquals(LOCATION, MODEL_INFO.getLocation());
}

@Test
Expand All @@ -88,6 +91,7 @@ public void testOf() {
assertNull(modelInfo.getDescription());
assertNull(modelInfo.getFriendlyName());
assertNull(modelInfo.getEncryptionConfiguration());
assertNull(modelInfo.getLocation());
assertEquals(modelInfo.getTrainingRuns().isEmpty(), true);
assertEquals(modelInfo.getLabelColumns().isEmpty(), true);
assertEquals(modelInfo.getFeatureColumns().isEmpty(), true);
Expand All @@ -113,6 +117,7 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) {
assertEquals(expected.getDescription(), value.getDescription());
assertEquals(expected.getFriendlyName(), value.getFriendlyName());
assertEquals(expected.getLabels(), value.getLabels());
assertEquals(expected.getLocation(), value.getLocation());
assertEquals(expected.hashCode(), value.hashCode());
assertEquals(expected.getTrainingRuns(), value.getTrainingRuns());
assertEquals(expected.getLabelColumns(), value.getLabelColumns());
Expand Down