From 9e8cd76759e584d743a1d5c310d8cd7299b3a39d Mon Sep 17 00:00:00 2001 From: Praful Makani Date: Tue, 12 May 2020 11:07:22 +0530 Subject: [PATCH] feat: add extract model for extractjobconfiguration (#227) * feat: add extract model * feat: modified code --- .../bigquery/ExtractJobConfiguration.java | 89 ++++++++++++++++++- .../bigquery/ExtractJobConfigurationTest.java | 57 ++++++++++++ .../cloud/bigquery/it/ITBigQueryTest.java | 38 ++++++++ 3 files changed, 180 insertions(+), 4 deletions(-) diff --git a/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ExtractJobConfiguration.java b/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ExtractJobConfiguration.java index 4dcfdaeaf..29a256e9e 100644 --- a/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ExtractJobConfiguration.java +++ b/google-cloud-bigquery/src/main/java/com/google/cloud/bigquery/ExtractJobConfiguration.java @@ -38,6 +38,7 @@ public final class ExtractJobConfiguration extends JobConfiguration { private static final long serialVersionUID = 4147749733166593761L; private final TableId sourceTable; + private final ModelId sourceModel; private final List destinationUris; private final Boolean printHeader; private final String fieldDelimiter; @@ -51,6 +52,7 @@ public static final class Builder extends JobConfiguration.Builder { private TableId sourceTable; + private ModelId sourceModel; private List destinationUris; private Boolean printHeader; private String fieldDelimiter; @@ -67,6 +69,7 @@ private Builder() { private Builder(ExtractJobConfiguration jobInfo) { this(); this.sourceTable = jobInfo.sourceTable; + this.sourceModel = jobInfo.sourceModel; this.destinationUris = jobInfo.destinationUris; this.printHeader = jobInfo.printHeader; this.fieldDelimiter = jobInfo.fieldDelimiter; @@ -80,7 +83,12 @@ private Builder(ExtractJobConfiguration jobInfo) { private Builder(com.google.api.services.bigquery.model.JobConfiguration configurationPb) { this(); JobConfigurationExtract extractConfigurationPb = configurationPb.getExtract(); - this.sourceTable = TableId.fromPb(extractConfigurationPb.getSourceTable()); + if (extractConfigurationPb.getSourceTable() != null) { + this.sourceTable = TableId.fromPb(extractConfigurationPb.getSourceTable()); + } + if (extractConfigurationPb.getSourceModel() != null) { + this.sourceModel = ModelId.fromPb(extractConfigurationPb.getSourceModel()); + } this.destinationUris = extractConfigurationPb.getDestinationUris(); this.printHeader = extractConfigurationPb.getPrintHeader(); this.fieldDelimiter = extractConfigurationPb.getFieldDelimiter(); @@ -101,6 +109,12 @@ public Builder setSourceTable(TableId sourceTable) { return this; } + /** Sets the model to export. */ + public Builder setSourceModel(ModelId sourceModel) { + this.sourceModel = sourceModel; + return this; + } + /** * Sets the list of fully-qualified Google Cloud Storage URIs (e.g. gs://bucket/path) where the * extracted table should be written. @@ -191,7 +205,8 @@ public ExtractJobConfiguration build() { private ExtractJobConfiguration(Builder builder) { super(builder); - this.sourceTable = checkNotNull(builder.sourceTable); + this.sourceTable = builder.sourceTable; + this.sourceModel = builder.sourceModel; this.destinationUris = checkNotNull(builder.destinationUris); this.printHeader = builder.printHeader; this.fieldDelimiter = builder.fieldDelimiter; @@ -207,6 +222,11 @@ public TableId getSourceTable() { return sourceTable; } + /** Returns the model to export. */ + public ModelId getSourceModel() { + return sourceModel; + } + /** * Returns the list of fully-qualified Google Cloud Storage URIs where the extracted table should * be written. @@ -263,6 +283,7 @@ public Builder toBuilder() { ToStringHelper toStringHelper() { return super.toStringHelper() .add("sourceTable", sourceTable) + .add("sourceModel", sourceModel) .add("destinationUris", destinationUris) .add("format", format) .add("printHeader", printHeader) @@ -284,6 +305,7 @@ public int hashCode() { return Objects.hash( baseHashCode(), sourceTable, + sourceModel, destinationUris, printHeader, fieldDelimiter, @@ -296,9 +318,12 @@ public int hashCode() { @Override ExtractJobConfiguration setProjectId(String projectId) { - if (Strings.isNullOrEmpty(getSourceTable().getProject())) { + if (getSourceTable() != null && Strings.isNullOrEmpty(getSourceTable().getProject())) { return toBuilder().setSourceTable(getSourceTable().setProjectId(projectId)).build(); } + if (getSourceModel() != null && Strings.isNullOrEmpty(getSourceModel().getProject())) { + return toBuilder().setSourceModel(getSourceModel().setProjectId(projectId)).build(); + } return this; } @@ -308,7 +333,12 @@ com.google.api.services.bigquery.model.JobConfiguration toPb() { com.google.api.services.bigquery.model.JobConfiguration jobConfiguration = new com.google.api.services.bigquery.model.JobConfiguration(); extractConfigurationPb.setDestinationUris(destinationUris); - extractConfigurationPb.setSourceTable(sourceTable.toPb()); + if (sourceTable != null) { + extractConfigurationPb.setSourceTable(sourceTable.toPb()); + } + if (sourceModel != null) { + extractConfigurationPb.setSourceModel(sourceModel.toPb()); + } extractConfigurationPb.setPrintHeader(printHeader); extractConfigurationPb.setFieldDelimiter(fieldDelimiter); extractConfigurationPb.setDestinationFormat(format); @@ -333,6 +363,15 @@ public static Builder newBuilder(TableId sourceTable, String destinationUri) { return newBuilder(sourceTable, ImmutableList.of(destinationUri)); } + /** + * Creates a builder for a BigQuery Extract Job configuration given source model and destination + * URI. + */ + public static Builder newBuilder(ModelId sourceModel, String destinationUri) { + checkArgument(!isNullOrEmpty(destinationUri), "Provided destinationUri is null or empty"); + return newBuilder(sourceModel, ImmutableList.of(destinationUri)); + } + /** * Creates a builder for a BigQuery Extract Job configuration given source table and destination * URIs. @@ -341,6 +380,14 @@ public static Builder newBuilder(TableId sourceTable, List destinationUr return new Builder().setSourceTable(sourceTable).setDestinationUris(destinationUris); } + /** + * Creates a builder for a BigQuery Extract Job configuration given source model and destination + * URIs. + */ + public static Builder newBuilder(ModelId sourceModel, List destinationUris) { + return new Builder().setSourceModel(sourceModel).setDestinationUris(destinationUris); + } + /** * Returns a BigQuery Extract Job configuration for the given source table and destination URI. */ @@ -348,6 +395,13 @@ public static ExtractJobConfiguration of(TableId sourceTable, String destination return newBuilder(sourceTable, destinationUri).build(); } + /** + * Returns a BigQuery Extract Job configuration for the given source model and destination URI. + */ + public static ExtractJobConfiguration of(ModelId sourceModel, String destinationUri) { + return newBuilder(sourceModel, destinationUri).build(); + } + /** * Returns a BigQuery Extract Job configuration for the given source table and destination URIs. */ @@ -355,6 +409,13 @@ public static ExtractJobConfiguration of(TableId sourceTable, List desti return newBuilder(sourceTable, destinationUris).build(); } + /** + * Returns a BigQuery Extract Job configuration for the given source model and destination URIs. + */ + public static ExtractJobConfiguration of(ModelId sourceModel, List destinationUris) { + return newBuilder(sourceModel, destinationUris).build(); + } + /** * Returns a BigQuery Extract Job configuration for the given source table, format and destination * URI. @@ -365,6 +426,16 @@ public static ExtractJobConfiguration of( return newBuilder(sourceTable, destinationUri).setFormat(format).build(); } + /** + * Returns a BigQuery Extract Job configuration for the given source model, format and destination + * URI. + */ + public static ExtractJobConfiguration of( + ModelId sourceTable, String destinationUri, String format) { + checkArgument(!isNullOrEmpty(format), "Provided format is null or empty"); + return newBuilder(sourceTable, destinationUri).setFormat(format).build(); + } + /** * Returns a BigQuery Extract Job configuration for the given source table, format and destination * URIs. @@ -375,6 +446,16 @@ public static ExtractJobConfiguration of( return newBuilder(sourceTable, destinationUris).setFormat(format).build(); } + /** + * Returns a BigQuery Extract Job configuration for the given source table, format and destination + * URIs. + */ + public static ExtractJobConfiguration of( + ModelId sourceModel, List destinationUris, String format) { + checkArgument(!isNullOrEmpty(format), "Provided format is null or empty"); + return newBuilder(sourceModel, destinationUris).setFormat(format).build(); + } + @SuppressWarnings("unchecked") static ExtractJobConfiguration fromPb( com.google.api.services.bigquery.model.JobConfiguration confPb) { diff --git a/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ExtractJobConfigurationTest.java b/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ExtractJobConfigurationTest.java index 648fa58a4..95142a068 100644 --- a/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ExtractJobConfigurationTest.java +++ b/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/ExtractJobConfigurationTest.java @@ -32,6 +32,7 @@ public class ExtractJobConfigurationTest { private static final List DESTINATION_URIS = ImmutableList.of("uri1", "uri2"); private static final String DESTINATION_URI = "uri1"; private static final TableId TABLE_ID = TableId.of("dataset", "table"); + private static final ModelId MODEL_ID = ModelId.of("dataset", "model"); private static final String FIELD_DELIMITER = ","; private static final String FORMAT = "CSV"; private static final String AVRO_FORMAT = "AVRO"; @@ -70,6 +71,16 @@ public class ExtractJobConfigurationTest { .setLabels(LABELS) .setJobTimeoutMs(TIMEOUT) .build(); + private static final ExtractJobConfiguration EXTRACT_CONFIGURATION_MODEL = + ExtractJobConfiguration.newBuilder(MODEL_ID, DESTINATION_URIS) + .setPrintHeader(PRINT_HEADER) + .setFieldDelimiter(FIELD_DELIMITER) + .setCompression(COMPRESSION) + .setFormat(FORMAT) + .setUseAvroLogicalTypes(USEAVROLOGICALTYPES) + .setLabels(LABELS) + .setJobTimeoutMs(TIMEOUT) + .build(); @Test public void testToBuilder() { @@ -78,6 +89,14 @@ public void testToBuilder() { ExtractJobConfiguration job = EXTRACT_CONFIGURATION.toBuilder().setSourceTable(TableId.of("dataset", "newTable")).build(); assertEquals("newTable", job.getSourceTable().getTable()); + compareExtractJobConfiguration( + EXTRACT_CONFIGURATION_MODEL, EXTRACT_CONFIGURATION_MODEL.toBuilder().build()); + ExtractJobConfiguration modelJob = + EXTRACT_CONFIGURATION_MODEL + .toBuilder() + .setSourceModel(ModelId.of("dataset", "newModel")) + .build(); + assertEquals("newModel", modelJob.getSourceModel().getModel()); job = job.toBuilder().setSourceTable(TABLE_ID).build(); compareExtractJobConfiguration(EXTRACT_CONFIGURATION, job); compareExtractJobConfiguration( @@ -108,12 +127,28 @@ public void testOf() { assertEquals(TABLE_ID, job.getSourceTable()); assertEquals(ImmutableList.of(DESTINATION_URI), job.getDestinationUris()); assertEquals(JSON_FORMAT, job.getFormat()); + ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS); + assertEquals(MODEL_ID, modelJob.getSourceModel()); + assertEquals(DESTINATION_URIS, modelJob.getDestinationUris()); + modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URI); + assertEquals(MODEL_ID, modelJob.getSourceModel()); + assertEquals(ImmutableList.of(DESTINATION_URI), modelJob.getDestinationUris()); + modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS, JSON_FORMAT); + assertEquals(MODEL_ID, modelJob.getSourceModel()); + assertEquals(DESTINATION_URIS, modelJob.getDestinationUris()); + assertEquals(JSON_FORMAT, modelJob.getFormat()); + modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URI, JSON_FORMAT); + assertEquals(MODEL_ID, modelJob.getSourceModel()); + assertEquals(ImmutableList.of(DESTINATION_URI), modelJob.getDestinationUris()); + assertEquals(JSON_FORMAT, modelJob.getFormat()); } @Test public void testToBuilderIncomplete() { ExtractJobConfiguration job = ExtractJobConfiguration.of(TABLE_ID, DESTINATION_URIS); compareExtractJobConfiguration(job, job.toBuilder().build()); + ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS); + compareExtractJobConfiguration(modelJob, modelJob.toBuilder().build()); } @Test @@ -144,6 +179,14 @@ public void testBuilder() { assertEquals(USEAVROLOGICALTYPES, EXTRACT_CONFIGURATION_AVRO.getUseAvroLogicalTypes()); assertEquals(LABELS, EXTRACT_CONFIGURATION_AVRO.getLabels()); assertEquals(TIMEOUT, EXTRACT_CONFIGURATION_AVRO.getJobTimeoutMs()); + assertEquals(MODEL_ID, EXTRACT_CONFIGURATION_MODEL.getSourceModel()); + assertEquals(DESTINATION_URIS, EXTRACT_CONFIGURATION_MODEL.getDestinationUris()); + assertEquals(FIELD_DELIMITER, EXTRACT_CONFIGURATION_MODEL.getFieldDelimiter()); + assertEquals(COMPRESSION, EXTRACT_CONFIGURATION_MODEL.getCompression()); + assertEquals(PRINT_HEADER, EXTRACT_CONFIGURATION_MODEL.printHeader()); + assertEquals(FORMAT, EXTRACT_CONFIGURATION_MODEL.getFormat()); + assertEquals(LABELS, EXTRACT_CONFIGURATION_MODEL.getLabels()); + assertEquals(TIMEOUT, EXTRACT_CONFIGURATION_MODEL.getJobTimeoutMs()); } @Test @@ -164,12 +207,17 @@ public void testToPbAndFromPb() { ExtractJobConfiguration.fromPb(EXTRACT_CONFIGURATION_AVRO.toPb())); ExtractJobConfiguration job = ExtractJobConfiguration.of(TABLE_ID, DESTINATION_URIS); compareExtractJobConfiguration(job, ExtractJobConfiguration.fromPb(job.toPb())); + ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS); + compareExtractJobConfiguration(modelJob, ExtractJobConfiguration.fromPb(modelJob.toPb())); } @Test public void testSetProjectId() { ExtractJobConfiguration configuration = EXTRACT_CONFIGURATION.setProjectId(TEST_PROJECT_ID); assertEquals(TEST_PROJECT_ID, configuration.getSourceTable().getProject()); + ExtractJobConfiguration modelConfiguration = + EXTRACT_CONFIGURATION_MODEL.setProjectId(TEST_PROJECT_ID); + assertEquals(TEST_PROJECT_ID, modelConfiguration.getSourceModel().getProject()); } @Test @@ -181,6 +229,13 @@ public void testSetProjectIdDoNotOverride() { .build() .setProjectId("do-not-update"); assertEquals(TEST_PROJECT_ID, configuration.getSourceTable().getProject()); + ExtractJobConfiguration modelConfiguration = + EXTRACT_CONFIGURATION_MODEL + .toBuilder() + .setSourceModel(MODEL_ID.setProjectId(TEST_PROJECT_ID)) + .build() + .setProjectId("do-not-update"); + assertEquals(TEST_PROJECT_ID, modelConfiguration.getSourceModel().getProject()); } @Test @@ -188,6 +243,7 @@ public void testGetType() { assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION.getType()); assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION_ONE_URI.getType()); assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION_AVRO.getType()); + assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION_MODEL.getType()); } private void compareExtractJobConfiguration( @@ -196,6 +252,7 @@ private void compareExtractJobConfiguration( assertEquals(expected.hashCode(), value.hashCode()); assertEquals(expected.toString(), value.toString()); assertEquals(expected.getSourceTable(), value.getSourceTable()); + assertEquals(expected.getSourceModel(), value.getSourceModel()); assertEquals(expected.getDestinationUris(), value.getDestinationUris()); assertEquals(expected.getCompression(), value.getCompression()); assertEquals(expected.printHeader(), value.printHeader()); diff --git a/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java b/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java index 165122d99..9e334c8a8 100644 --- a/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java +++ b/google-cloud-bigquery/src/test/java/com/google/cloud/bigquery/it/ITBigQueryTest.java @@ -254,6 +254,7 @@ public class ITBigQueryTest { private static final String LOAD_FILE = "load.csv"; private static final String JSON_LOAD_FILE = "load.json"; private static final String EXTRACT_FILE = "extract.csv"; + private static final String EXTRACT_MODEL_FILE = "extract_model.csv"; private static final String BUCKET = RemoteStorageHelper.generateBucketName(); private static final TableId TABLE_ID = TableId.of(DATASET, "testing_table"); private static final String CSV_CONTENT = "StringValue1\nStringValue2\n"; @@ -1945,6 +1946,43 @@ public void testExtractJob() throws InterruptedException, TimeoutException { assertTrue(bigquery.delete(destinationTable)); } + @Test + public void testExtractJobWithModel() throws InterruptedException { + String modelName = RemoteBigQueryHelper.generateModelName(); + String sql = + "CREATE MODEL `" + + MODEL_DATASET + + "." + + modelName + + "`" + + "OPTIONS ( " + + "model_type='linear_reg', " + + "max_iteration=1, " + + "learn_rate=0.4, " + + "learn_rate_strategy='constant' " + + ") AS ( " + + " SELECT 'a' AS f1, 2.0 AS label " + + "UNION ALL " + + "SELECT 'b' AS f1, 3.8 AS label " + + ")"; + + QueryJobConfiguration config = QueryJobConfiguration.newBuilder(sql).build(); + Job job = bigquery.create(JobInfo.of(JobId.of(), config)); + job.waitFor(); + assertNull(job.getStatus().getError()); + ModelId destinationModel = ModelId.of(MODEL_DATASET, modelName); + assertNotNull(destinationModel); + ExtractJobConfiguration extractConfiguration = + ExtractJobConfiguration.newBuilder( + destinationModel, "gs://" + BUCKET + "/" + EXTRACT_MODEL_FILE) + .setPrintHeader(false) + .build(); + Job remoteExtractJob = bigquery.create(JobInfo.of(extractConfiguration)); + remoteExtractJob = remoteExtractJob.waitFor(); + assertNull(remoteExtractJob.getStatus().getError()); + assertTrue(bigquery.delete(destinationModel)); + } + @Test public void testExtractJobWithLabels() throws InterruptedException, TimeoutException { String tableName = "test_export_job_table_label";