Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add extract model for extractjobconfiguration #227

Merged
merged 2 commits into from May 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -38,6 +38,7 @@ public final class ExtractJobConfiguration extends JobConfiguration {
private static final long serialVersionUID = 4147749733166593761L;

private final TableId sourceTable;
private final ModelId sourceModel;
private final List<String> destinationUris;
private final Boolean printHeader;
private final String fieldDelimiter;
Expand All @@ -51,6 +52,7 @@ public static final class Builder
extends JobConfiguration.Builder<ExtractJobConfiguration, Builder> {

private TableId sourceTable;
private ModelId sourceModel;
private List<String> destinationUris;
private Boolean printHeader;
private String fieldDelimiter;
Expand All @@ -67,6 +69,7 @@ private Builder() {
private Builder(ExtractJobConfiguration jobInfo) {
this();
this.sourceTable = jobInfo.sourceTable;
this.sourceModel = jobInfo.sourceModel;
this.destinationUris = jobInfo.destinationUris;
this.printHeader = jobInfo.printHeader;
this.fieldDelimiter = jobInfo.fieldDelimiter;
Expand All @@ -80,7 +83,12 @@ private Builder(ExtractJobConfiguration jobInfo) {
private Builder(com.google.api.services.bigquery.model.JobConfiguration configurationPb) {
this();
JobConfigurationExtract extractConfigurationPb = configurationPb.getExtract();
this.sourceTable = TableId.fromPb(extractConfigurationPb.getSourceTable());
if (extractConfigurationPb.getSourceTable() != null) {
this.sourceTable = TableId.fromPb(extractConfigurationPb.getSourceTable());
}
if (extractConfigurationPb.getSourceModel() != null) {
this.sourceModel = ModelId.fromPb(extractConfigurationPb.getSourceModel());
}
this.destinationUris = extractConfigurationPb.getDestinationUris();
this.printHeader = extractConfigurationPb.getPrintHeader();
this.fieldDelimiter = extractConfigurationPb.getFieldDelimiter();
Expand All @@ -101,6 +109,12 @@ public Builder setSourceTable(TableId sourceTable) {
return this;
}

/** Sets the model to export. */
public Builder setSourceModel(ModelId sourceModel) {
this.sourceModel = sourceModel;
return this;
}

/**
* Sets the list of fully-qualified Google Cloud Storage URIs (e.g. gs://bucket/path) where the
* extracted table should be written.
Expand Down Expand Up @@ -191,7 +205,8 @@ public ExtractJobConfiguration build() {

private ExtractJobConfiguration(Builder builder) {
super(builder);
this.sourceTable = checkNotNull(builder.sourceTable);
this.sourceTable = builder.sourceTable;
this.sourceModel = builder.sourceModel;
this.destinationUris = checkNotNull(builder.destinationUris);
this.printHeader = builder.printHeader;
this.fieldDelimiter = builder.fieldDelimiter;
Expand All @@ -207,6 +222,11 @@ public TableId getSourceTable() {
return sourceTable;
}

/** Returns the model to export. */
public ModelId getSourceModel() {
return sourceModel;
}

/**
* Returns the list of fully-qualified Google Cloud Storage URIs where the extracted table should
* be written.
Expand Down Expand Up @@ -263,6 +283,7 @@ public Builder toBuilder() {
ToStringHelper toStringHelper() {
return super.toStringHelper()
.add("sourceTable", sourceTable)
.add("sourceModel", sourceModel)
.add("destinationUris", destinationUris)
.add("format", format)
.add("printHeader", printHeader)
Expand All @@ -284,6 +305,7 @@ public int hashCode() {
return Objects.hash(
baseHashCode(),
sourceTable,
sourceModel,
destinationUris,
printHeader,
fieldDelimiter,
Expand All @@ -296,9 +318,12 @@ public int hashCode() {

@Override
ExtractJobConfiguration setProjectId(String projectId) {
if (Strings.isNullOrEmpty(getSourceTable().getProject())) {
if (getSourceTable() != null && Strings.isNullOrEmpty(getSourceTable().getProject())) {
return toBuilder().setSourceTable(getSourceTable().setProjectId(projectId)).build();
}
if (getSourceModel() != null && Strings.isNullOrEmpty(getSourceModel().getProject())) {
return toBuilder().setSourceModel(getSourceModel().setProjectId(projectId)).build();
}
return this;
}

Expand All @@ -308,7 +333,12 @@ com.google.api.services.bigquery.model.JobConfiguration toPb() {
com.google.api.services.bigquery.model.JobConfiguration jobConfiguration =
new com.google.api.services.bigquery.model.JobConfiguration();
extractConfigurationPb.setDestinationUris(destinationUris);
extractConfigurationPb.setSourceTable(sourceTable.toPb());
if (sourceTable != null) {
extractConfigurationPb.setSourceTable(sourceTable.toPb());
}
if (sourceModel != null) {
extractConfigurationPb.setSourceModel(sourceModel.toPb());
}
extractConfigurationPb.setPrintHeader(printHeader);
extractConfigurationPb.setFieldDelimiter(fieldDelimiter);
extractConfigurationPb.setDestinationFormat(format);
Expand All @@ -333,6 +363,15 @@ public static Builder newBuilder(TableId sourceTable, String destinationUri) {
return newBuilder(sourceTable, ImmutableList.of(destinationUri));
}

/**
* Creates a builder for a BigQuery Extract Job configuration given source model and destination
* URI.
*/
public static Builder newBuilder(ModelId sourceModel, String destinationUri) {
checkArgument(!isNullOrEmpty(destinationUri), "Provided destinationUri is null or empty");
return newBuilder(sourceModel, ImmutableList.of(destinationUri));
}

/**
* Creates a builder for a BigQuery Extract Job configuration given source table and destination
* URIs.
Expand All @@ -341,20 +380,42 @@ public static Builder newBuilder(TableId sourceTable, List<String> destinationUr
return new Builder().setSourceTable(sourceTable).setDestinationUris(destinationUris);
}

/**
* Creates a builder for a BigQuery Extract Job configuration given source model and destination
* URIs.
*/
public static Builder newBuilder(ModelId sourceModel, List<String> destinationUris) {
return new Builder().setSourceModel(sourceModel).setDestinationUris(destinationUris);
}

/**
* Returns a BigQuery Extract Job configuration for the given source table and destination URI.
*/
public static ExtractJobConfiguration of(TableId sourceTable, String destinationUri) {
return newBuilder(sourceTable, destinationUri).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source model and destination URI.
*/
public static ExtractJobConfiguration of(ModelId sourceModel, String destinationUri) {
return newBuilder(sourceModel, destinationUri).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source table and destination URIs.
*/
public static ExtractJobConfiguration of(TableId sourceTable, List<String> destinationUris) {
return newBuilder(sourceTable, destinationUris).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source model and destination URIs.
*/
public static ExtractJobConfiguration of(ModelId sourceModel, List<String> destinationUris) {
return newBuilder(sourceModel, destinationUris).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source table, format and destination
* URI.
Expand All @@ -365,6 +426,16 @@ public static ExtractJobConfiguration of(
return newBuilder(sourceTable, destinationUri).setFormat(format).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source model, format and destination
* URI.
*/
public static ExtractJobConfiguration of(
ModelId sourceTable, String destinationUri, String format) {
checkArgument(!isNullOrEmpty(format), "Provided format is null or empty");
return newBuilder(sourceTable, destinationUri).setFormat(format).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source table, format and destination
* URIs.
Expand All @@ -375,6 +446,16 @@ public static ExtractJobConfiguration of(
return newBuilder(sourceTable, destinationUris).setFormat(format).build();
}

/**
* Returns a BigQuery Extract Job configuration for the given source table, format and destination
* URIs.
*/
public static ExtractJobConfiguration of(
ModelId sourceModel, List<String> destinationUris, String format) {
checkArgument(!isNullOrEmpty(format), "Provided format is null or empty");
return newBuilder(sourceModel, destinationUris).setFormat(format).build();
}

@SuppressWarnings("unchecked")
static ExtractJobConfiguration fromPb(
com.google.api.services.bigquery.model.JobConfiguration confPb) {
Expand Down
Expand Up @@ -32,6 +32,7 @@ public class ExtractJobConfigurationTest {
private static final List<String> DESTINATION_URIS = ImmutableList.of("uri1", "uri2");
private static final String DESTINATION_URI = "uri1";
private static final TableId TABLE_ID = TableId.of("dataset", "table");
private static final ModelId MODEL_ID = ModelId.of("dataset", "model");
private static final String FIELD_DELIMITER = ",";
private static final String FORMAT = "CSV";
private static final String AVRO_FORMAT = "AVRO";
Expand Down Expand Up @@ -70,6 +71,16 @@ public class ExtractJobConfigurationTest {
.setLabels(LABELS)
.setJobTimeoutMs(TIMEOUT)
.build();
private static final ExtractJobConfiguration EXTRACT_CONFIGURATION_MODEL =
ExtractJobConfiguration.newBuilder(MODEL_ID, DESTINATION_URIS)
.setPrintHeader(PRINT_HEADER)
.setFieldDelimiter(FIELD_DELIMITER)
.setCompression(COMPRESSION)
.setFormat(FORMAT)
.setUseAvroLogicalTypes(USEAVROLOGICALTYPES)
.setLabels(LABELS)
.setJobTimeoutMs(TIMEOUT)
.build();

@Test
public void testToBuilder() {
Expand All @@ -78,6 +89,14 @@ public void testToBuilder() {
ExtractJobConfiguration job =
EXTRACT_CONFIGURATION.toBuilder().setSourceTable(TableId.of("dataset", "newTable")).build();
assertEquals("newTable", job.getSourceTable().getTable());
compareExtractJobConfiguration(
EXTRACT_CONFIGURATION_MODEL, EXTRACT_CONFIGURATION_MODEL.toBuilder().build());
ExtractJobConfiguration modelJob =
EXTRACT_CONFIGURATION_MODEL
.toBuilder()
.setSourceModel(ModelId.of("dataset", "newModel"))
.build();
assertEquals("newModel", modelJob.getSourceModel().getModel());
job = job.toBuilder().setSourceTable(TABLE_ID).build();
compareExtractJobConfiguration(EXTRACT_CONFIGURATION, job);
compareExtractJobConfiguration(
Expand Down Expand Up @@ -108,12 +127,28 @@ public void testOf() {
assertEquals(TABLE_ID, job.getSourceTable());
assertEquals(ImmutableList.of(DESTINATION_URI), job.getDestinationUris());
assertEquals(JSON_FORMAT, job.getFormat());
ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(DESTINATION_URIS, modelJob.getDestinationUris());
modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URI);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(ImmutableList.of(DESTINATION_URI), modelJob.getDestinationUris());
modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS, JSON_FORMAT);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(DESTINATION_URIS, modelJob.getDestinationUris());
assertEquals(JSON_FORMAT, modelJob.getFormat());
modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URI, JSON_FORMAT);
assertEquals(MODEL_ID, modelJob.getSourceModel());
assertEquals(ImmutableList.of(DESTINATION_URI), modelJob.getDestinationUris());
assertEquals(JSON_FORMAT, modelJob.getFormat());
}

@Test
public void testToBuilderIncomplete() {
ExtractJobConfiguration job = ExtractJobConfiguration.of(TABLE_ID, DESTINATION_URIS);
compareExtractJobConfiguration(job, job.toBuilder().build());
ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS);
compareExtractJobConfiguration(modelJob, modelJob.toBuilder().build());
}

@Test
Expand Down Expand Up @@ -144,6 +179,14 @@ public void testBuilder() {
assertEquals(USEAVROLOGICALTYPES, EXTRACT_CONFIGURATION_AVRO.getUseAvroLogicalTypes());
assertEquals(LABELS, EXTRACT_CONFIGURATION_AVRO.getLabels());
assertEquals(TIMEOUT, EXTRACT_CONFIGURATION_AVRO.getJobTimeoutMs());
assertEquals(MODEL_ID, EXTRACT_CONFIGURATION_MODEL.getSourceModel());
assertEquals(DESTINATION_URIS, EXTRACT_CONFIGURATION_MODEL.getDestinationUris());
assertEquals(FIELD_DELIMITER, EXTRACT_CONFIGURATION_MODEL.getFieldDelimiter());
assertEquals(COMPRESSION, EXTRACT_CONFIGURATION_MODEL.getCompression());
assertEquals(PRINT_HEADER, EXTRACT_CONFIGURATION_MODEL.printHeader());
assertEquals(FORMAT, EXTRACT_CONFIGURATION_MODEL.getFormat());
assertEquals(LABELS, EXTRACT_CONFIGURATION_MODEL.getLabels());
assertEquals(TIMEOUT, EXTRACT_CONFIGURATION_MODEL.getJobTimeoutMs());
}

@Test
Expand All @@ -164,12 +207,17 @@ public void testToPbAndFromPb() {
ExtractJobConfiguration.fromPb(EXTRACT_CONFIGURATION_AVRO.toPb()));
ExtractJobConfiguration job = ExtractJobConfiguration.of(TABLE_ID, DESTINATION_URIS);
compareExtractJobConfiguration(job, ExtractJobConfiguration.fromPb(job.toPb()));
ExtractJobConfiguration modelJob = ExtractJobConfiguration.of(MODEL_ID, DESTINATION_URIS);
compareExtractJobConfiguration(modelJob, ExtractJobConfiguration.fromPb(modelJob.toPb()));
}

@Test
public void testSetProjectId() {
ExtractJobConfiguration configuration = EXTRACT_CONFIGURATION.setProjectId(TEST_PROJECT_ID);
assertEquals(TEST_PROJECT_ID, configuration.getSourceTable().getProject());
ExtractJobConfiguration modelConfiguration =
EXTRACT_CONFIGURATION_MODEL.setProjectId(TEST_PROJECT_ID);
assertEquals(TEST_PROJECT_ID, modelConfiguration.getSourceModel().getProject());
}

@Test
Expand All @@ -181,13 +229,21 @@ public void testSetProjectIdDoNotOverride() {
.build()
.setProjectId("do-not-update");
assertEquals(TEST_PROJECT_ID, configuration.getSourceTable().getProject());
ExtractJobConfiguration modelConfiguration =
EXTRACT_CONFIGURATION_MODEL
.toBuilder()
.setSourceModel(MODEL_ID.setProjectId(TEST_PROJECT_ID))
.build()
.setProjectId("do-not-update");
assertEquals(TEST_PROJECT_ID, modelConfiguration.getSourceModel().getProject());
}

@Test
public void testGetType() {
assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION.getType());
assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION_ONE_URI.getType());
assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION_AVRO.getType());
assertEquals(JobConfiguration.Type.EXTRACT, EXTRACT_CONFIGURATION_MODEL.getType());
}

private void compareExtractJobConfiguration(
Expand All @@ -196,6 +252,7 @@ private void compareExtractJobConfiguration(
assertEquals(expected.hashCode(), value.hashCode());
assertEquals(expected.toString(), value.toString());
assertEquals(expected.getSourceTable(), value.getSourceTable());
assertEquals(expected.getSourceModel(), value.getSourceModel());
assertEquals(expected.getDestinationUris(), value.getDestinationUris());
assertEquals(expected.getCompression(), value.getCompression());
assertEquals(expected.printHeader(), value.printHeader());
Expand Down
Expand Up @@ -226,6 +226,7 @@ public class ITBigQueryTest {
private static final String LOAD_FILE = "load.csv";
private static final String JSON_LOAD_FILE = "load.json";
private static final String EXTRACT_FILE = "extract.csv";
private static final String EXTRACT_MODEL_FILE = "extract_model.csv";
private static final String BUCKET = RemoteStorageHelper.generateBucketName();
private static final TableId TABLE_ID = TableId.of(DATASET, "testing_table");
private static final String CSV_CONTENT = "StringValue1\nStringValue2\n";
Expand Down Expand Up @@ -1790,6 +1791,43 @@ public void testExtractJob() throws InterruptedException, TimeoutException {
assertTrue(bigquery.delete(destinationTable));
}

@Test
public void testExtractJobWithModel() throws InterruptedException {
String modelName = RemoteBigQueryHelper.generateModelName();
String sql =
"CREATE MODEL `"
+ MODEL_DATASET
+ "."
+ modelName
+ "`"
+ "OPTIONS ( "
+ "model_type='linear_reg', "
+ "max_iteration=1, "
+ "learn_rate=0.4, "
+ "learn_rate_strategy='constant' "
+ ") AS ( "
+ " SELECT 'a' AS f1, 2.0 AS label "
+ "UNION ALL "
+ "SELECT 'b' AS f1, 3.8 AS label "
+ ")";

QueryJobConfiguration config = QueryJobConfiguration.newBuilder(sql).build();
Job job = bigquery.create(JobInfo.of(JobId.of(), config));
job.waitFor();
assertNull(job.getStatus().getError());
ModelId destinationModel = ModelId.of(MODEL_DATASET, modelName);
assertNotNull(destinationModel);
ExtractJobConfiguration extractConfiguration =
ExtractJobConfiguration.newBuilder(
destinationModel, "gs://" + BUCKET + "/" + EXTRACT_MODEL_FILE)
.setPrintHeader(false)
.build();
Job remoteExtractJob = bigquery.create(JobInfo.of(extractConfiguration));
remoteExtractJob = remoteExtractJob.waitFor();
assertNull(remoteExtractJob.getStatus().getError());
assertTrue(bigquery.delete(destinationModel));
}

@Test
public void testExtractJobWithLabels() throws InterruptedException, TimeoutException {
String tableName = "test_export_job_table_label";
Expand Down