Skip to content

Commit

Permalink
feat: add extract model for extractjobconfiguration (#227)
Browse files Browse the repository at this point in the history
* feat: add extract model

* feat: modified code
  • Loading branch information
Praful Makani committed May 12, 2020
1 parent 510a80e commit 9e8cd76
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 4 deletions.
Expand Up @@ -38,6 +38,7 @@ public final class ExtractJobConfiguration extends JobConfiguration {
private static final long serialVersionUID = 4147749733166593761L;

private final TableId sourceTable;
private final ModelId sourceModel;
private final List<String> destinationUris;
private final Boolean printHeader;
private final String fieldDelimiter;
Expand All @@ -51,6 +52,7 @@ public static final class Builder
extends JobConfiguration.Builder<ExtractJobConfiguration, Builder> {

private TableId sourceTable;
private ModelId sourceModel;
private List<String> destinationUris;
private Boolean printHeader;
private String fieldDelimiter;
Expand All @@ -67,6 +69,7 @@ private Builder() {
private Builder(ExtractJobConfiguration jobInfo) {
this();
this.sourceTable = jobInfo.sourceTable;
this.sourceModel = jobInfo.sourceModel;
this.destinationUris = jobInfo.destinationUris;
this.printHeader = jobInfo.printHeader;
this.fieldDelimiter = jobInfo.fieldDelimiter;
Expand All @@ -80,7 +83,12 @@ private Builder(ExtractJobConfiguration jobInfo) {
private Builder(com.google.api.services.bigquery.model.JobConfiguration configurationPb) {
this();
JobConfigurationExtract extractConfigurationPb = configurationPb.getExtract();
this.sourceTable = TableId.fromPb(extractConfigurationPb.getSourceTable());
if (extractConfigurationPb.getSourceTable() != null) {
this.sourceTable = TableId.fromPb(extractConfigurationPb.getSourceTable());
}
if (extractConfigurationPb.getSourceModel() != null) {
this.sourceModel = ModelId.fromPb(extractConfigurationPb.getSourceModel());
}
this.destinationUris = extractConfigurationPb.getDestinationUris();
this.printHeader = extractConfigurationPb.getPrintHeader();
this.fieldDelimiter = extractConfigurationPb.getFieldDelimiter();
Expand All @@ -101,6 +109,12 @@ public Builder setSourceTable(TableId sourceTable) {
return this;
}

/**
* Sets the model to export.
*
* @param sourceModel identity of the model to export; stored as given, without validation
* @return this builder, so that calls can be chained
*/
public Builder setSourceModel(ModelId sourceModel) {
this.sourceModel = sourceModel;
return this;
}

/**
* Sets the list of fully-qualified Google Cloud Storage URIs (e.g. gs://bucket/path) where the
* extracted table should be written.
Expand Down Expand Up @@ -191,7 +205,8 @@ public ExtractJobConfiguration build() {

private ExtractJobConfiguration(Builder builder) {
super(builder);
this.sourceTable = checkNotNull(builder.sourceTable);
this.sourceTable = builder.sourceTable;
this.sourceModel = builder.sourceModel;
this.destinationUris = checkNotNull(builder.destinationUris);
this.printHeader = builder.printHeader;
this.fieldDelimiter = builder.fieldDelimiter;
Expand All @@ -207,6 +222,11 @@ public TableId getSourceTable() {
return sourceTable;
}

/**
* Returns the model to export.
*
* <p>May be {@code null}: the constructor copies the builder's value without a null check, and
* the table-based factories never call {@code setSourceModel}.
*/
public ModelId getSourceModel() {
return sourceModel;
}

/**
* Returns the list of fully-qualified Google Cloud Storage URIs where the extracted table should
* be written.
Expand Down Expand Up @@ -263,6 +283,7 @@ public Builder toBuilder() {
ToStringHelper toStringHelper() {
return super.toStringHelper()
.add("sourceTable", sourceTable)
.add("sourceModel", sourceModel)
.add("destinationUris", destinationUris)
.add("format", format)
.add("printHeader", printHeader)
Expand All @@ -284,6 +305,7 @@ public int hashCode() {
return Objects.hash(
baseHashCode(),
sourceTable,
sourceModel,
destinationUris,
printHeader,
fieldDelimiter,
Expand All @@ -296,9 +318,12 @@ public int hashCode() {

@Override
ExtractJobConfiguration setProjectId(String projectId) {
if (Strings.isNullOrEmpty(getSourceTable().getProject())) {
if (getSourceTable() != null && Strings.isNullOrEmpty(getSourceTable().getProject())) {
return toBuilder().setSourceTable(getSourceTable().setProjectId(projectId)).build();
}
if (getSourceModel() != null && Strings.isNullOrEmpty(getSourceModel().getProject())) {
return toBuilder().setSourceModel(getSourceModel().setProjectId(projectId)).build();
}
return this;
}

Expand All @@ -308,7 +333,12 @@ com.google.api.services.bigquery.model.JobConfiguration toPb() {
com.google.api.services.bigquery.model.JobConfiguration jobConfiguration =
new com.google.api.services.bigquery.model.JobConfiguration();
extractConfigurationPb.setDestinationUris(destinationUris);
extractConfigurationPb.setSourceTable(sourceTable.toPb());
if (sourceTable != null) {
extractConfigurationPb.setSourceTable(sourceTable.toPb());
}
if (sourceModel != null) {
extractConfigurationPb.setSourceModel(sourceModel.toPb());
}
extractConfigurationPb.setPrintHeader(printHeader);
extractConfigurationPb.setFieldDelimiter(fieldDelimiter);
extractConfigurationPb.setDestinationFormat(format);
Expand All @@ -333,6 +363,15 @@ public static Builder newBuilder(TableId sourceTable, String destinationUri) {
return newBuilder(sourceTable, ImmutableList.of(destinationUri));
}

/**
 * Creates a builder for a BigQuery Extract Job configuration that exports the given source model
 * to a single destination URI.
 *
 * <p>The URI is rejected by {@code checkArgument} when it is null or empty; otherwise it is
 * wrapped in a one-element list and handed to the list-based overload.
 */
public static Builder newBuilder(ModelId sourceModel, String destinationUri) {
  checkArgument(!isNullOrEmpty(destinationUri), "Provided destinationUri is null or empty");
  List<String> singleDestination = ImmutableList.of(destinationUri);
  return newBuilder(sourceModel, singleDestination);
}

/**
* Creates a builder for a BigQuery Extract Job configuration given source table and destination
* URIs.
Expand All @@ -341,20 +380,42 @@ public static Builder newBuilder(TableId sourceTable, List<String> destinationUr
return new Builder().setSourceTable(sourceTable).setDestinationUris(destinationUris);
}

/**
 * Creates a builder for a BigQuery Extract Job configuration that exports the given source model
 * to the given destination URIs.
 */
public static Builder newBuilder(ModelId sourceModel, List<String> destinationUris) {
  Builder builder = new Builder();
  // setSourceModel returns the builder itself, so its result does not need to be captured.
  builder.setSourceModel(sourceModel);
  return builder.setDestinationUris(destinationUris);
}

/**
 * Builds a BigQuery Extract Job configuration that exports the given source table to a single
 * destination URI, setting no other option.
 */
public static ExtractJobConfiguration of(TableId sourceTable, String destinationUri) {
  Builder builder = newBuilder(sourceTable, destinationUri);
  return builder.build();
}

/**
 * Builds a BigQuery Extract Job configuration that exports the given source model to a single
 * destination URI, setting no other option.
 */
public static ExtractJobConfiguration of(ModelId sourceModel, String destinationUri) {
  Builder builder = newBuilder(sourceModel, destinationUri);
  return builder.build();
}

/**
 * Builds a BigQuery Extract Job configuration that exports the given source table to the given
 * destination URIs, setting no other option.
 */
public static ExtractJobConfiguration of(TableId sourceTable, List<String> destinationUris) {
  Builder builder = newBuilder(sourceTable, destinationUris);
  return builder.build();
}

/**
 * Builds a BigQuery Extract Job configuration that exports the given source model to the given
 * destination URIs, setting no other option.
 */
public static ExtractJobConfiguration of(ModelId sourceModel, List<String> destinationUris) {
  Builder builder = newBuilder(sourceModel, destinationUris);
  return builder.build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source table, format and destination
* URI.
Expand All @@ -365,6 +426,16 @@ public static ExtractJobConfiguration of(
return newBuilder(sourceTable, destinationUri).setFormat(format).build();
}

/**
 * Returns a BigQuery Extract Job configuration for the given source model, format and destination
 * URI.
 *
 * <p>{@code format} is rejected by {@code checkArgument} when it is null or empty. The
 * destination URI is checked the same way by the {@code newBuilder} overload this method
 * delegates to.
 *
 * @param sourceModel the model to export
 * @param destinationUri the single URI the export is written to
 * @param format the export format, applied through {@code setFormat}
 */
public static ExtractJobConfiguration of(
    ModelId sourceModel, String destinationUri, String format) {
  checkArgument(!isNullOrEmpty(format), "Provided format is null or empty");
  return newBuilder(sourceModel, destinationUri).setFormat(format).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source table, format and destination
* URIs.
Expand All @@ -375,6 +446,16 @@ public static ExtractJobConfiguration of(
return newBuilder(sourceTable, destinationUris).setFormat(format).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source model, format and destination
* URIs.
*
* <p>{@code format} is rejected by {@code checkArgument} when it is null or empty; it is then
* applied to the builder through {@code setFormat}.
*/
public static ExtractJobConfiguration of(
ModelId sourceModel, List<String> destinationUris, String format) {
checkArgument(!isNullOrEmpty(format), "Provided format is null or empty");
return newBuilder(sourceModel, destinationUris).setFormat(format).build();
}

@SuppressWarnings("unchecked")
static ExtractJobConfiguration fromPb(
com.google.api.services.bigquery.model.JobConfiguration confPb) {
Expand Down
Expand Up @@ -32,6 +32,7 @@ public class ExtractJobConfigurationTest {
private static final List<String> DESTINATION_URIS = ImmutableList.of("uri1", "uri2");
private static final String DESTINATION_URI = "uri1";
private static final TableId TABLE_ID = TableId.of("dataset", "table");
private static final ModelId MODEL_ID = ModelId.of("dataset", "model");
private static final String FIELD_DELIMITER = ",";
private static final String FORMAT = "CSV";
private static final String AVRO_FORMAT = "AVRO";
Expand Down Expand Up @@ -70,6 +71,16 @@ public class ExtractJobConfigurationTest {
.setLabels(LABELS)
.setJobTimeoutMs(TIMEOUT)
.build();
// Fixture whose source is MODEL_ID instead of a table; it writes to DESTINATION_URIS and sets
// print header, field delimiter, compression, format, Avro logical types, labels and timeout
// from the shared test constants.
private static final ExtractJobConfiguration EXTRACT_CONFIGURATION_MODEL =
ExtractJobConfiguration.newBuilder(MODEL_ID, DESTINATION_URIS)
.setPrintHeader(PRINT_HEADER)
.setFieldDelimiter(FIELD_DELIMITER)
.setCompression(COMPRESSION)
.setFormat(FORMAT)
.setUseAvroLogicalTypes(USEAVROLOGICALTYPES)
.setLabels(LABELS)
.setJobTimeoutMs(TIMEOUT)
.build();

@Test
public void testToBuilder() {
Expand All @@ -78,6 +89,14 @@ public void testToBuilder() {
ExtractJobConfiguration job =
EXTRACT_CONFIGURATION.toBuilder().setSourceTable(TableId.of("dataset", "newTable")).build();
assertEquals("newTable", job.getSourceTable().getTable());
compareExtractJobConfiguration(
EXTRACT_CONFIGURATION_MODEL, EXTRACT_CONFIGURATION_MODEL.toBuilder().build());
ExtractJobConfiguration modelJob =
EXTRACT_CONFIGURATION_MODEL
.toBuilder()
.setSourceModel(ModelId.of("dataset", "newModel"))
.build();
assertEquals("newModel", modelJob.getSourceModel().getModel());
job = job.toBuilder().setSourceTable(TABLE_ID).build();
compareExtractJobConfiguration(EXTRACT_CONFIGURATION, job);
compareExtractJobConfiguration(
Expand Down Expand Up @@ -108,12 +127,28 @@ public void testOf() {
assertEquals(TABLE_ID, job.getSourceTable());
assertEquals(ImmutableList.of(DESTINATION_URI), job.getDestinationUris());
assertEquals(JSON_FORMAT, job.getFormat());
ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(DESTINATION_URIS, modelJob.getDestinationUris());
modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URI);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(ImmutableList.of(DESTINATION_URI), modelJob.getDestinationUris());
modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS, JSON_FORMAT);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(DESTINATION_URIS, modelJob.getDestinationUris());
assertEquals(JSON_FORMAT, modelJob.getFormat());
modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URI, JSON_FORMAT);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(ImmutableList.of(DESTINATION_URI), modelJob.getDestinationUris());
assertEquals(JSON_FORMAT, modelJob.getFormat());
}

/**
 * Checks that a configuration created with only a source and destination URIs survives a
 * {@code toBuilder().build()} round trip, for both a table source and a model source.
 */
@Test
public void testToBuilderIncomplete() {
  ExtractJobConfiguration tableConfig = ExtractJobConfiguration.of(TABLE_ID, DESTINATION_URIS);
  ExtractJobConfiguration tableCopy = tableConfig.toBuilder().build();
  compareExtractJobConfiguration(tableConfig, tableCopy);

  ExtractJobConfiguration modelConfig = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS);
  ExtractJobConfiguration modelCopy = modelConfig.toBuilder().build();
  compareExtractJobConfiguration(modelConfig, modelCopy);
}

@Test
Expand Down Expand Up @@ -144,6 +179,14 @@ public void testBuilder() {
assertEquals(USEAVROLOGICALTYPES, EXTRACT_CONFIGURATION_AVRO.getUseAvroLogicalTypes());
assertEquals(LABELS, EXTRACT_CONFIGURATION_AVRO.getLabels());
assertEquals(TIMEOUT, EXTRACT_CONFIGURATION_AVRO.getJobTimeoutMs());
assertEquals(MODEL_ID, EXTRACT_CONFIGURATION_MODEL.getSourceModel());
assertEquals(DESTINATION_URIS, EXTRACT_CONFIGURATION_MODEL.getDestinationUris());
assertEquals(FIELD_DELIMITER, EXTRACT_CONFIGURATION_MODEL.getFieldDelimiter());
assertEquals(COMPRESSION, EXTRACT_CONFIGURATION_MODEL.getCompression());
assertEquals(PRINT_HEADER, EXTRACT_CONFIGURATION_MODEL.printHeader());
assertEquals(FORMAT, EXTRACT_CONFIGURATION_MODEL.getFormat());
assertEquals(LABELS, EXTRACT_CONFIGURATION_MODEL.getLabels());
assertEquals(TIMEOUT, EXTRACT_CONFIGURATION_MODEL.getJobTimeoutMs());
}

@Test
Expand All @@ -164,12 +207,17 @@ public void testToPbAndFromPb() {
ExtractJobConfiguration.fromPb(EXTRACT_CONFIGURATION_AVRO.toPb()));
ExtractJobConfiguration job = ExtractJobConfiguration.of(TABLE_ID, DESTINATION_URIS);
compareExtractJobConfiguration(job, ExtractJobConfiguration.fromPb(job.toPb()));
ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS);
compareExtractJobConfiguration(modelJob, ExtractJobConfiguration.fromPb(modelJob.toPb()));
}

/**
 * Checks that {@code setProjectId} fills in the project on the source table of a table-based
 * fixture and on the source model of a model-based fixture.
 */
@Test
public void testSetProjectId() {
  TableId updatedTable = EXTRACT_CONFIGURATION.setProjectId(TEST_PROJECT_ID).getSourceTable();
  assertEquals(TEST_PROJECT_ID, updatedTable.getProject());

  ModelId updatedModel =
      EXTRACT_CONFIGURATION_MODEL.setProjectId(TEST_PROJECT_ID).getSourceModel();
  assertEquals(TEST_PROJECT_ID, updatedModel.getProject());
}

@Test
Expand All @@ -181,13 +229,21 @@ public void testSetProjectIdDoNotOverride() {
.build()
.setProjectId("do-not-update");
assertEquals(TEST_PROJECT_ID, configuration.getSourceTable().getProject());
ExtractJobConfiguration modelConfiguration =
EXTRACT_CONFIGURATION_MODEL
.toBuilder()
.setSourceModel(MODEL_ID.setProjectId(TEST_PROJECT_ID))
.build()
.setProjectId("do-not-update");
assertEquals(TEST_PROJECT_ID, modelConfiguration.getSourceModel().getProject());
}

/** Checks that every extract fixture, table- or model-based, reports the EXTRACT job type. */
@Test
public void testGetType() {
  ExtractJobConfiguration[] fixtures = {
    EXTRACT_CONFIGURATION,
    EXTRACT_CONFIGURATION_ONE_URI,
    EXTRACT_CONFIGURATION_AVRO,
    EXTRACT_CONFIGURATION_MODEL
  };
  for (ExtractJobConfiguration fixture : fixtures) {
    assertEquals(JobConfiguration.Type.EXTRACT, fixture.getType());
  }
}

private void compareExtractJobConfiguration(
Expand All @@ -196,6 +252,7 @@ private void compareExtractJobConfiguration(
assertEquals(expected.hashCode(), value.hashCode());
assertEquals(expected.toString(), value.toString());
assertEquals(expected.getSourceTable(), value.getSourceTable());
assertEquals(expected.getSourceModel(), value.getSourceModel());
assertEquals(expected.getDestinationUris(), value.getDestinationUris());
assertEquals(expected.getCompression(), value.getCompression());
assertEquals(expected.printHeader(), value.printHeader());
Expand Down
Expand Up @@ -254,6 +254,7 @@ public class ITBigQueryTest {
private static final String LOAD_FILE = "load.csv";
private static final String JSON_LOAD_FILE = "load.json";
private static final String EXTRACT_FILE = "extract.csv";
private static final String EXTRACT_MODEL_FILE = "extract_model.csv";
private static final String BUCKET = RemoteStorageHelper.generateBucketName();
private static final TableId TABLE_ID = TableId.of(DATASET, "testing_table");
private static final String CSV_CONTENT = "StringValue1\nStringValue2\n";
Expand Down Expand Up @@ -1945,6 +1946,43 @@ public void testExtractJob() throws InterruptedException, TimeoutException {
assertTrue(bigquery.delete(destinationTable));
}

/**
 * Creates a small linear-regression model with a query job, exports it to the test bucket with
 * an extract job, and checks that neither job reports an error before deleting the model.
 */
@Test
public void testExtractJobWithModel() throws InterruptedException {
  String modelName = RemoteBigQueryHelper.generateModelName();
  String sql =
      "CREATE MODEL `"
          + MODEL_DATASET
          + "."
          + modelName
          + "`"
          + "OPTIONS ( "
          + "model_type='linear_reg', "
          + "max_iteration=1, "
          + "learn_rate=0.4, "
          + "learn_rate_strategy='constant' "
          + ") AS ( "
          + " SELECT 'a' AS f1, 2.0 AS label "
          + "UNION ALL "
          + "SELECT 'b' AS f1, 3.8 AS label "
          + ")";

  QueryJobConfiguration config = QueryJobConfiguration.newBuilder(sql).build();
  Job job = bigquery.create(JobInfo.of(JobId.of(), config));
  // Keep the job returned by waitFor(), as is done for the extract job below: asserting on the
  // handle returned by create() would only inspect the status captured at submission time.
  job = job.waitFor();
  assertNull(job.getStatus().getError());
  ModelId destinationModel = ModelId.of(MODEL_DATASET, modelName);
  assertNotNull(destinationModel);
  ExtractJobConfiguration extractConfiguration =
      ExtractJobConfiguration.newBuilder(
              destinationModel, "gs://" + BUCKET + "/" + EXTRACT_MODEL_FILE)
          .setPrintHeader(false)
          .build();
  Job remoteExtractJob = bigquery.create(JobInfo.of(extractConfiguration));
  remoteExtractJob = remoteExtractJob.waitFor();
  assertNull(remoteExtractJob.getStatus().getError());
  assertTrue(bigquery.delete(destinationModel));
}

@Test
public void testExtractJobWithLabels() throws InterruptedException, TimeoutException {
String tableName = "test_export_job_table_label";
Expand Down

0 comments on commit 9e8cd76

Please sign in to comment.