From 646c2b43b2ccbc609e0d5b85a7e4fbf502bb1243 Mon Sep 17 00:00:00 2001 From: Praful Makani Date: Tue, 18 Feb 2020 23:30:55 +0530 Subject: [PATCH] feat: expose location field of model (#175) --- .../java/com/google/cloud/bigquery/Model.java | 6 ++++++ .../com/google/cloud/bigquery/ModelInfo.java | 20 +++++++++++++++++++ .../google/cloud/bigquery/ModelInfoTest.java | 5 +++++ 3 files changed, 31 insertions(+) diff --git a/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java b/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java index abe1f0f2b..64ef0e4a7 100644 --- a/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java +++ b/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/Model.java @@ -108,6 +108,12 @@ public Builder setLabels(Map labels) { return this; } + @Override + Builder setLocation(String location) { + infoBuilder.setLocation(location); + return this; + } + @Override Builder setTrainingRuns(List trainingRunList) { infoBuilder.setTrainingRuns(trainingRunList); diff --git a/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java b/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java index 5796c820f..83603cbd2 100644 --- a/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java +++ b/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ModelInfo.java @@ -68,6 +68,7 @@ public Model apply(ModelInfo ModelInfo) { private final Long lastModifiedTime; private final Long expirationTime; private final Labels labels; + private final String location; private final ImmutableList trainingRunList; private final ImmutableList featureColumnList; private final ImmutableList labelColumnList; @@ -97,6 +98,8 @@ public abstract static class Builder { */ public abstract Builder setLabels(Map labels); + abstract Builder setLocation(String location); + public abstract Builder setModelId(ModelId modelId); abstract Builder setEtag(String etag); @@ -130,6 +133,7 @@ static class BuilderImpl extends Builder { private Long lastModifiedTime; private Long expirationTime; private Labels labels = Labels.ZERO; + private String location; private List trainingRunList = Collections.emptyList(); private List labelColumnList = Collections.emptyList(); private List featureColumnList = Collections.emptyList(); @@ -150,6 +154,7 @@ static class BuilderImpl extends Builder { this.labelColumnList = modelInfo.labelColumnList; this.featureColumnList = modelInfo.featureColumnList; this.encryptionConfiguration = modelInfo.encryptionConfiguration; + this.location = modelInfo.location; } BuilderImpl(Model modelPb) { @@ -165,6 +170,7 @@ static class BuilderImpl extends Builder { this.lastModifiedTime = modelPb.getLastModifiedTime(); this.expirationTime = modelPb.getExpirationTime(); this.labels = Labels.fromPb(modelPb.getLabels()); + this.location = modelPb.getLocation(); if (modelPb.getTrainingRuns() != null) { this.trainingRunList = modelPb.getTrainingRuns(); } @@ -236,6 +242,12 @@ public Builder setLabels(Map labels) { return this; } + @Override + Builder setLocation(String location) { + this.location = location; + return this; + } + @Override Builder setTrainingRuns(List trainingRunList) { this.trainingRunList = checkNotNull(trainingRunList); @@ -276,6 +288,7 @@ public ModelInfo build() { this.lastModifiedTime = builder.lastModifiedTime; this.expirationTime = builder.expirationTime; this.labels = builder.labels; + this.location = builder.location; this.trainingRunList = ImmutableList.copyOf(builder.trainingRunList); this.labelColumnList = ImmutableList.copyOf(builder.labelColumnList); this.featureColumnList = ImmutableList.copyOf(builder.featureColumnList); @@ -330,6 +343,11 @@ public Map getLabels() { return labels.userMap(); } + /** Returns a location of the model. */ + public String getLocation() { + return location; + } + /** Returns metadata about each training run iteration. */ @BetaApi public ImmutableList getTrainingRuns() { @@ -368,6 +386,7 @@ public String toString() { .add("lastModifiedTime", lastModifiedTime) .add("expirationTime", expirationTime) .add("labels", labels) + .add("location", location) .add("trainingRuns", trainingRunList) .add("labelColumns", labelColumnList) .add("featureColumns", featureColumnList) @@ -416,6 +435,7 @@ Model toPb() { modelPb.setLastModifiedTime(lastModifiedTime); modelPb.setExpirationTime(expirationTime); modelPb.setLabels(labels.toPb()); + modelPb.setLocation(location); modelPb.setTrainingRuns(trainingRunList); if (labelColumnList != null) { modelPb.setLabelColumns(Lists.transform(labelColumnList, StandardSQLField.TO_PB_FUNCTION)); diff --git a/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java b/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java index 2657ccc44..87fa8bbf5 100644 --- a/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java +++ b/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ModelInfoTest.java @@ -33,6 +33,7 @@ public class ModelInfoTest { private static final Long EXPIRATION_TIME = 30L; private static final String DESCRIPTION = "description"; private static final String FRIENDLY_NAME = "friendlyname"; + private static final String LOCATION = "US"; private static final EncryptionConfiguration MODEL_ENCRYPTION_CONFIGURATION = EncryptionConfiguration.newBuilder().setKmsKeyName("KMS_KEY_1").build(); @@ -52,6 +53,7 @@ public class ModelInfoTest { .setFriendlyName(FRIENDLY_NAME) .setTrainingRuns(TRAINING_RUN_LIST) .setEncryptionConfiguration(MODEL_ENCRYPTION_CONFIGURATION) + .setLocation(LOCATION) .build(); @Test @@ -75,6 +77,7 @@ public void testBuilder() { assertEquals(FRIENDLY_NAME, MODEL_INFO.getFriendlyName()); assertEquals(TRAINING_OPTIONS, MODEL_INFO.getTrainingRuns().get(0).getTrainingOptions()); assertEquals(MODEL_ENCRYPTION_CONFIGURATION, MODEL_INFO.getEncryptionConfiguration()); + assertEquals(LOCATION, MODEL_INFO.getLocation()); } @Test @@ -88,6 +91,7 @@ public void testOf() { assertNull(modelInfo.getDescription()); assertNull(modelInfo.getFriendlyName()); assertNull(modelInfo.getEncryptionConfiguration()); + assertNull(modelInfo.getLocation()); assertEquals(modelInfo.getTrainingRuns().isEmpty(), true); assertEquals(modelInfo.getLabelColumns().isEmpty(), true); assertEquals(modelInfo.getFeatureColumns().isEmpty(), true); @@ -113,6 +117,7 @@ private void compareModelInfo(ModelInfo expected, ModelInfo value) { assertEquals(expected.getDescription(), value.getDescription()); assertEquals(expected.getFriendlyName(), value.getFriendlyName()); assertEquals(expected.getLabels(), value.getLabels()); + assertEquals(expected.getLocation(), value.getLocation()); assertEquals(expected.hashCode(), value.hashCode()); assertEquals(expected.getTrainingRuns(), value.getTrainingRuns()); assertEquals(expected.getLabelColumns(), value.getLabelColumns());